import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, default_collate
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from torchvision.transforms.functional import to_tensor

def read_camera_matrix_single(json_file):
    """Read a single gobjaverse camera json and return the 3x4 c2w matrix.

    The negative signs on the y and z axes convert the renderer's OpenCV
    camera convention to the OpenGL convention.
    """
    with open(json_file, "r", encoding="utf8") as reader:
        json_content = json.load(reader)

    camera_matrix = torch.zeros(3, 4)
    camera_matrix[:3, 0] = torch.tensor(json_content["x"])
    camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
    camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
    camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
    return camera_matrix

def read_camera_intrinsics_single(json_file, h: int, w: int, scale: float = 1.0):
    with open(json_file, "r", encoding="utf8") as reader:
        json_content = json.load(reader)

    h = int(h * scale)
    w = int(w * scale)
    y_fov = json_content["y_fov"]
    x_fov = json_content["x_fov"]
    fy = h / 2 / np.tan(y_fov / 2)
    fx = w / 2 / np.tan(x_fov / 2)
    cx = w // 2
    cy = h // 2

    intrinsics = torch.tensor(
        [
            [fx, fy],
            [cx, cy],
            [w, h],
        ],
        dtype=torch.float32,
    )
    return intrinsics

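# The intrinsics above are packed as a (3, 2) tensor rather than the usual
# 3x3 K matrix: row 0 is (fx, fy), row 1 is (cx, cy), row 2 is
# (width, height). A minimal sketch of unpacking this layout back into a
# standard 3x3 K matrix (illustrative helper only, not used by the
# pipeline below):
def _intrinsics_to_K(intrinsics: torch.Tensor) -> torch.Tensor:
    (fx, fy), (cx, cy), _ = intrinsics.tolist()
    return torch.tensor(
        [[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=torch.float32
    )
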
def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Compose the standard form extrinsic matrix from RT.
    Batched I/O.
    """
    return torch.cat(
        [
            RT,
            torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(
                RT.shape[0], 1, 1
            ),
        ],
        dim=1,
    )

def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
    """
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    Return batched fx, fy, cx, cy, normalized by the image size.
    """
    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
    cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
    width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
    fx, fy = fx / width, fy / height
    cx, cy = cx / width, cy / height
    return fx, fy, cx, cy

def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    RT: (N, 3, 4)
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    Returns (N, 25): the flattened 4x4 extrinsic followed by the flattened
    normalized 3x3 intrinsic for each camera.
    """
    E = compose_extrinsic_RT(RT)
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    I = torch.stack(
        [
            torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
            torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
            torch.tensor([[0, 0, 1]], dtype=torch.float32).repeat(RT.shape[0], 1),
        ],
        dim=1,
    )
    return torch.cat(
        [
            E.reshape(-1, 16),
            I.reshape(-1, 9),
        ],
        dim=-1,
    )

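# Shape sketch for the 25-D "standard camera" vector above (a hypothetical
# sanity check, not part of the dataset pipeline): one identity pose with a
# 256x256 pinhole camera should produce a single 25-D row.
#
#   RT = torch.eye(4)[:3].unsqueeze(0)  # (1, 3, 4)
#   K = torch.tensor([[[128.0, 128.0], [128.0, 128.0], [256.0, 256.0]]])
#   assert build_camera_standard(RT, K).shape == (1, 25)
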
def calc_elevation(c2w):
    ## works for single or batched c2w
    ## assume world up is (0, 0, 1)
    pos = c2w[..., :3, 3]
    return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False))

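# Elevation here is the angle between the camera position and the z = 0
# plane: with world up (0, 0, 1), elevation = arcsin(z / ||pos||), so a
# camera directly overhead at (0, 0, r) gives pi/2 and a camera on the
# equator gives 0.
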
def blend_white_bg(image):
    """Composite an RGBA PIL image onto a white background, returning RGB."""
    new_image = Image.new("RGB", image.size, (255, 255, 255))
    new_image.paste(image, mask=image.split()[3])
    return new_image

def flatten_for_video(x):
    return x.flatten()


FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"]

def video_collate_fn(batch: list[dict], *args, **kwargs):
    out = {}
    for key in batch[0].keys():
        if key in FLATTEN_FIELDS:
            out[key] = default_collate([item[key] for item in batch])
            out[key] = flatten_for_video(out[key])
        elif key == "num_video_frames":
            out[key] = batch[0][key]
        elif key in ["frames", "latents", "rgb"]:
            out[key] = default_collate([item[key] for item in batch])
            out[key] = rearrange(out[key], "b t c h w -> (b t) c h w")
        else:
            out[key] = default_collate([item[key] for item in batch])

    if "pixelnerf_input" in out:
        out["pixelnerf_input"]["rgb"] = rearrange(
            out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w"
        )

    return out

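# The collation above follows an SVD-style video-batch convention:
# per-sample (T, C, H, W) frame stacks are first collated to
# (B, T, C, H, W) and then merged to (B*T, C, H, W), while scalar per-frame
# fields such as fps_id are flattened to length B*T. For example (shapes
# only): two samples with frames of shape (24, 3, 256, 256) collate to a
# "frames" tensor of shape (48, 3, 256, 256) and a "fps_id" of shape (48,).
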
class GObjaverse(Dataset):
    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        fps_id=0.0,
        motion_bucket_id=300.0,
        use_latents=False,
        load_caps=False,
        front_view_selection="random",
        load_pixelnerf=False,
        debug_base_idx=None,
        scale_pose: bool = False,
        max_n_cond: int = 1,
        **unused_kwargs,
    ):
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_latents = use_latents
        self.ids = json.load(open(self.root_dir / "valid_uids.json", "r"))
        self.n_views = 24
        self.load_caps = load_caps
        if self.load_caps:
            self.caps = json.load(
                open(self.root_dir / "text_captions_cap3d.json", "r")
            )

        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation
        self.fps_id = fps_id
        self.motion_bucket_id = motion_bucket_id
        self.load_pixelnerf = load_pixelnerf
        self.scale_pose = scale_pose
        self.max_n_cond = max_n_cond

        if self.use_latents:
            self.latents_dir = self.root_dir / "latents256"
            self.clip_dir = self.root_dir / "clip_emb256"

        self.front_view_selection = front_view_selection
        if self.front_view_selection == "random":
            pass
        elif self.front_view_selection == "fixed":
            pass
        elif self.front_view_selection.startswith("clip_score"):
            self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt")
            self.ids = list(self.clip_scores.keys())
        else:
            raise ValueError(
                f"Unknown front view selection method {self.front_view_selection}"
            )

        if max_item is not None:
            self.ids = self.ids[:max_item]

            ## debug
            self.ids = self.ids * 10000

        if debug_base_idx is not None:
            print(f"debug mode with base idx: {debug_base_idx}")
            self.debug_base_idx = debug_base_idx

    def __getitem__(self, idx: int):
        if hasattr(self, "debug_base_idx"):
            idx = (idx + self.debug_base_idx) % len(self.ids)
        data = {}

        idx_list = np.arange(self.n_views)
        if self.front_view_selection == "random":
            roll_idx = np.random.randint(self.n_views)
            idx_list = np.roll(idx_list, roll_idx)
        elif self.front_view_selection == "fixed":
            pass
        elif self.front_view_selection == "clip_score_softmax":
            this_clip_score = (
                F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
            )
            roll_idx = np.random.choice(idx_list, p=this_clip_score)
            idx_list = np.roll(idx_list, roll_idx)
        elif self.front_view_selection == "clip_score_max":
            this_clip_score = (
                F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
            )
            roll_idx = np.argmax(this_clip_score)
            idx_list = np.roll(idx_list, roll_idx)

        frames = []
        if not self.use_latents:
            try:
                for view_idx in idx_list:
                    frame = Image.open(
                        self.root_dir
                        / "gobjaverse"
                        / self.ids[idx]
                        / f"{view_idx:05d}/{view_idx:05d}.png"
                    )
                    frames.append(self.transform(frame))
            except Exception:
                # a workaround for some broken items in gobjaverse: fall
                # back to idx 0; the repeat is resolved when gathering
                # results, and the valid number of items can be checked via
                # the length of the results
                idx = 0
                frames = []
                for view_idx in idx_list:
                    frame = Image.open(
                        self.root_dir
                        / "gobjaverse"
                        / self.ids[idx]
                        / f"{view_idx:05d}/{view_idx:05d}.png"
                    )
                    frames.append(self.transform(frame))
            frames = torch.stack(frames, dim=0)
            cond = frames[0]
            cond_aug = np.exp(
                np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
            )

            data.update(
                {
                    "frames": frames,
                    "cond_frames_without_noise": cond,
                    "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
                    "cond_frames": cond + cond_aug * torch.randn_like(cond),
                    "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
                    "motion_bucket_id": torch.as_tensor(
                        [self.motion_bucket_id] * self.n_views
                    ),
                    "num_video_frames": 24,
                    "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
                }
            )
        else:
            latents = torch.load(self.latents_dir / f"{self.ids[idx]}.pt")[idx_list]
            clip_emb = torch.load(self.clip_dir / f"{self.ids[idx]}.pt")[idx_list][0]
            cond = latents[0]
            cond_aug = np.exp(
                np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
            )

            data.update(
                {
                    "latents": latents,
                    "cond_frames_without_noise": clip_emb,
                    "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
                    "cond_frames": cond + cond_aug * torch.randn_like(cond),
                    "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
                    "motion_bucket_id": torch.as_tensor(
                        [self.motion_bucket_id] * self.n_views
                    ),
                    "num_video_frames": 24,
                    "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
                }
            )

        if self.condition_on_elevation:
            sample_c2w = read_camera_matrix_single(
                self.root_dir / self.ids[idx] / "00000/00000.json"
            )
            elevation = calc_elevation(sample_c2w)
            data["elevation"] = torch.as_tensor([elevation] * self.n_views)

        if self.load_pixelnerf:
            assert "frames" in data, "pixelnerf cannot work in latents-only mode"
            data["pixelnerf_input"] = {}
            RTs = []
            intrinsics = []
            for view_idx in idx_list:
                meta = (
                    self.root_dir
                    / "gobjaverse"
                    / self.ids[idx]
                    / f"{view_idx:05d}/{view_idx:05d}.json"
                )
                RTs.append(read_camera_matrix_single(meta)[:3])
                intrinsics.append(read_camera_intrinsics_single(meta, 256, 256))
            RTs = torch.stack(RTs, dim=0)
            intrinsics = torch.stack(intrinsics, dim=0)
            cameras = build_camera_standard(RTs, intrinsics)
            data["pixelnerf_input"]["cameras"] = cameras

            downsampled = []
            for view_idx in idx_list:
                frame = Image.open(
                    self.root_dir
                    / "gobjaverse"
                    / self.ids[idx]
                    / f"{view_idx:05d}/{view_idx:05d}.png"
                ).resize((32, 32))
                downsampled.append(to_tensor(blend_white_bg(frame)))
            data["pixelnerf_input"]["rgb"] = torch.stack(downsampled, dim=0)
            data["pixelnerf_input"]["frames"] = data["frames"]

            if self.scale_pose:
                # recenter and rescale camera positions so the scene fits
                # in a radius-1.5 ball
                c2ws = cameras[..., :16].reshape(-1, 4, 4)
                center = c2ws[:, :3, 3].mean(0)
                radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
                scale = 1.5 / radius
                c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
                cameras[..., :16] = c2ws.reshape(-1, 16)

        if self.load_caps:
            data["caption"] = self.caps[self.ids[idx]]
            data["ids"] = self.ids[idx]

        return data

    def __len__(self):
        return len(self.ids)

    def collate_fn(self, batch):
        if self.max_n_cond > 1:
            # pick the number of conditioning views once per batch so every
            # sample in the batch has the same shape
            n_cond = np.random.randint(1, self.max_n_cond + 1)
            if n_cond > 1:
                for b in batch:
                    source_index = [0] + np.random.choice(
                        np.arange(1, self.n_views),
                        self.max_n_cond - 1,
                        replace=False,
                    ).tolist()
                    b["pixelnerf_input"]["source_index"] = torch.as_tensor(
                        source_index
                    )
                    b["pixelnerf_input"]["n_cond"] = n_cond
                    b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
                    b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
                        "cameras"
                    ][source_index]

        return video_collate_fn(batch)

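# With max_n_cond > 1, the collate step above augments each sample with
# extra conditioning views for pixelNeRF: view 0 (the front view) is always
# kept, and max_n_cond - 1 additional distinct views are drawn at random;
# n_cond presumably tells downstream code how many of them to actually use.
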
class ObjaverseSpiral(Dataset):
    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        **unused_kwargs,
    ):
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform

        self.ids = json.load(open(self.root_dir / f"{split}_ids.json", "r"))
        self.n_views = 24
        valid_ids = []
        for idx in self.ids:
            if (self.root_dir / idx).exists():
                valid_ids.append(idx)
        self.ids = valid_ids

        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation

        if max_item is not None:
            self.ids = self.ids[:max_item]

            ## debug
            self.ids = self.ids * 10000

    def __getitem__(self, idx: int):
        frames = []
        idx_list = np.arange(self.n_views)
        if self.random_front:
            roll_idx = np.random.randint(self.n_views)
            idx_list = np.roll(idx_list, roll_idx)
        for view_idx in idx_list:
            frame = Image.open(
                self.root_dir / self.ids[idx] / f"{view_idx:05d}/{view_idx:05d}.png"
            )
            frames.append(self.transform(frame))
        frames = torch.stack(frames, dim=0)  # [T, C, H, W]
        cond = frames[0]
        cond_aug = np.exp(
            np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
        )

        data = {
            "frames": frames,
            "cond_frames_without_noise": cond,
            "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
            "cond_frames": cond + cond_aug * torch.randn_like(cond),
            "fps_id": torch.as_tensor([1.0] * self.n_views),
            "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
            "num_video_frames": 24,
            "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
        }

        if self.condition_on_elevation:
            sample_c2w = read_camera_matrix_single(
                self.root_dir / self.ids[idx] / "00000/00000.json"
            )
            elevation = calc_elevation(sample_c2w)
            data["elevation"] = torch.as_tensor([elevation] * self.n_views)

        return data

    def __len__(self):
        return len(self.ids)

class ObjaverseLVISSpiral(Dataset):
    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        use_precomputed_latents=False,
        **unused_kwargs,
    ):
        print("Using LVIS subset")
        self.root_dir = Path(root_dir)
        self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_precomputed_latents = use_precomputed_latents

        self.ids = json.load(open("./assets/lvis_uids.json", "r"))
        self.n_views = 18
        valid_ids = []
        for idx in self.ids:
            if (self.root_dir / idx).exists():
                valid_ids.append(idx)
        self.ids = valid_ids
        print("=" * 30)
        print("Number of valid ids: ", len(self.ids))
        print("=" * 30)

        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation

        if max_item is not None:
            self.ids = self.ids[:max_item]

            ## debug
            self.ids = self.ids * 10000

    def __getitem__(self, idx: int):
        frames = []
        idx_list = np.arange(self.n_views)
        if self.random_front:
            roll_idx = np.random.randint(self.n_views)
            idx_list = np.roll(idx_list, roll_idx)
        for view_idx in idx_list:
            frame = Image.open(
                self.root_dir
                / self.ids[idx]
                / "elevations_0"
                / f"colors_{view_idx * 2}.png"
            )
            frames.append(self.transform(frame))
        frames = torch.stack(frames, dim=0)
        cond = frames[0]
        cond_aug = np.exp(
            np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
        )

        data = {
            "frames": frames,
            "cond_frames_without_noise": cond,
            "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
            "cond_frames": cond + cond_aug * torch.randn_like(cond),
            "fps_id": torch.as_tensor([0.0] * self.n_views),
            "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
            "num_video_frames": self.n_views,
            "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
        }

        if self.use_precomputed_latents:
            data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")

        if self.condition_on_elevation:
            raise NotImplementedError("this dataset currently assumes elevation 0")

        return data

    def __len__(self):
        return len(self.ids)

class ObjaverseALLSpiral(ObjaverseLVISSpiral):
    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        use_precomputed_latents=False,
        **unused_kwargs,
    ):
        print("Using ALL objects in Objaverse")
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_precomputed_latents = use_precomputed_latents
        self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")

        self.ids = json.load(open("./assets/all_ids.json", "r"))
        self.n_views = 18
        valid_ids = []
        for idx in self.ids:
            if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
                valid_ids.append(idx)
        self.ids = valid_ids
        print("=" * 30)
        print("Number of valid ids: ", len(self.ids))
        print("=" * 30)

        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation

        if max_item is not None:
            self.ids = self.ids[:max_item]

            ## debug
            self.ids = self.ids * 10000

class ObjaverseWithPose(Dataset):
    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        use_precomputed_latents=False,
        **unused_kwargs,
    ):
        print("Using Objaverse with poses")
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_precomputed_latents = use_precomputed_latents
        self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")

        self.ids = json.load(open("./assets/all_ids.json", "r"))
        self.n_views = 18
        valid_ids = []
        for idx in self.ids:
            if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
                valid_ids.append(idx)
        self.ids = valid_ids
        print("=" * 30)
        print("Number of valid ids: ", len(self.ids))
        print("=" * 30)

        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation

    def __getitem__(self, idx: int):
        frames = []
        idx_list = np.arange(self.n_views)
        if self.random_front:
            roll_idx = np.random.randint(self.n_views)
            idx_list = np.roll(idx_list, roll_idx)
        for view_idx in idx_list:
            frame = Image.open(
                self.root_dir
                / self.ids[idx]
                / "elevations_0"
                / f"colors_{view_idx * 2}.png"
            )
            frames.append(self.transform(frame))
        frames = torch.stack(frames, dim=0)
        cond = frames[0]
        cond_aug = np.exp(
            np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
        )

        data = {
            "frames": frames,
            "cond_frames_without_noise": cond,
            "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
            "cond_frames": cond + cond_aug * torch.randn_like(cond),
            "fps_id": torch.as_tensor([0.0] * self.n_views),
            "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
            "num_video_frames": self.n_views,
            "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
        }

        if self.use_precomputed_latents:
            data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")

        if self.condition_on_elevation:
            raise NotImplementedError("this dataset currently assumes elevation 0")

        return data

    def __len__(self):
        # mirrors the sibling datasets; required for DataLoader sampling
        return len(self.ids)

class LatentObjaverse(Dataset):
    def __init__(
        self,
        root_dir,
        split="train",
        random_front=False,
        subset="lvis",
        fps_id=1.0,
        motion_bucket_id=300.0,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        **unused_kwargs,
    ):
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.ids = json.load(open(Path("./assets") / f"{subset}_ids.json", "r"))
        self.clip_emb_dir = self.root_dir / ".." / "clip_emb512"
        self.n_views = 18
        self.fps_id = fps_id
        self.motion_bucket_id = motion_bucket_id
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        if self.random_front:
            print("Using a random view as front view")

        valid_ids = []
        for idx in self.ids:
            if (self.root_dir / f"{idx}.pt").exists() and (
                self.clip_emb_dir / f"{idx}.pt"
            ).exists():
                valid_ids.append(idx)
        self.ids = valid_ids
        print("=" * 30)
        print("Number of valid ids: ", len(self.ids))
        print("=" * 30)

    def __getitem__(self, idx: int):
        uid = self.ids[idx]
        idx_list = torch.arange(self.n_views)
        latents = torch.load(self.root_dir / f"{uid}.pt")
        clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt")
        if self.random_front:
            idx_list = torch.roll(idx_list, np.random.randint(self.n_views))
        latents = latents[idx_list]
        clip_emb = clip_emb[idx_list][0]
        cond = latents[0]
        cond_aug = np.exp(
            np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
        )

        data = {
            "latents": latents,
            "cond_frames_without_noise": clip_emb,
            "cond_frames": cond + cond_aug * torch.randn_like(cond),
            "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
            "motion_bucket_id": torch.as_tensor(
                [self.motion_bucket_id] * self.n_views
            ),
            "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
            "num_video_frames": self.n_views,
            "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
        }
        return data

    def __len__(self):
        return len(self.ids)

class ObjaverseSpiralDataset(LightningDataModule):
    def __init__(
        self,
        root_dir,
        random_front=False,
        batch_size=2,
        num_workers=10,
        prefetch_factor=2,
        shuffle=True,
        max_item=None,
        dataset_cls="richdreamer",
        reso: int = 256,
        **kwargs,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.shuffle = shuffle
        self.max_item = max_item

        self.transform = Compose(
            [
                blend_white_bg,
                Resize((reso, reso)),
                ToTensor(),
                Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        data_cls = {
            "richdreamer": ObjaverseSpiral,
            "lvis": ObjaverseLVISSpiral,
            "shengshu_all": ObjaverseALLSpiral,
            "latent": LatentObjaverse,
            "gobjaverse": GObjaverse,
        }[dataset_cls]

        self.train_dataset = data_cls(
            root_dir=root_dir,
            split="train",
            random_front=random_front,
            transform=self.transform,
            max_item=self.max_item,
            **kwargs,
        )
        self.test_dataset = data_cls(
            root_dir=root_dir,
            split="val",
            random_front=random_front,
            transform=self.transform,
            max_item=self.max_item,
            **kwargs,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
            # datasets may provide their own collate_fn (e.g. GObjaverse);
            # otherwise fall back to the generic video collation
            collate_fn=getattr(self.train_dataset, "collate_fn", video_collate_fn),
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
            collate_fn=getattr(self.test_dataset, "collate_fn", video_collate_fn),
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
            collate_fn=getattr(self.test_dataset, "collate_fn", video_collate_fn),
        )
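

if __name__ == "__main__":
    # Minimal smoke-test sketch. The root path below is a placeholder:
    # point it at a directory that actually contains the rendered views and
    # the {split}_ids.json index files.
    dm = ObjaverseSpiralDataset(
        root_dir="/path/to/renders",  # hypothetical location
        dataset_cls="richdreamer",
        batch_size=1,
        num_workers=0,
        prefetch_factor=None,  # prefetch_factor must be None when num_workers=0
    )
    batch = next(iter(dm.train_dataloader()))
    for k, v in batch.items():
        print(k, v.shape if isinstance(v, torch.Tensor) else v)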