# NOTE(review): removed page-scrape residue ("Spaces:" / "Runtime error" x2)
# that preceded this module — it was not part of the source file.
| import math | |
| import os | |
| import json | |
| import re | |
| import cv2 | |
| from dataclasses import dataclass, field | |
| import random | |
| import imageio | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from torch.utils.data import DataLoader, Dataset | |
| from PIL import Image | |
| from step1x3d_geometry.utils.typing import * | |
@dataclass
class BaseDataModuleConfig:
    """Configuration for the geometry/image/caption data module.

    BUGFIX: the ``@dataclass`` decorator was missing.  The class already relied
    on dataclass machinery — ``field(default_factory=...)`` for
    ``background_color`` — so without the decorator that attribute was a bare
    ``dataclasses.Field`` object and no generated ``__init__`` existed.
    """

    root_dir: Optional[str] = None  # dataset root; split json + data live under it
    batch_size: int = 4
    num_workers: int = 8

    ################################# General argumentation #################################
    random_flip: bool = (
        False  # whether to randomly flip the input point cloud and the input images
    )

    ################################# Geometry part #################################
    load_geometry: bool = True  # whether to load geometry data
    with_sharp_data: bool = False
    geo_data_type: str = "sdf"  # occupancy, sdf

    # for occupancy or sdf supervision
    n_samples: int = 4096  # number of points in input point cloud
    upsample_ratio: int = 1  # upsample ratio for input point cloud
    sampling_strategy: Optional[str] = (
        "random"  # sampling strategy for input point cloud
    )
    scale: float = 1.0  # scale of the input point cloud and target supervision
    noise_sigma: float = 0.0  # noise level of the input point cloud
    rotate_points: bool = (
        False  # whether to rotate the input point cloud and the supervision, for VAE aug.
    )
    load_geometry_supervision: bool = False  # whether to load supervision
    supervision_type: str = "sdf"  # occupancy, sdf, tsdf, tsdf_w_surface
    n_supervision: int = 10000  # number of points in supervision
    tsdf_threshold: float = (
        0.01  # threshold for truncating sdf values, used when input is sdf
    )

    ################################# Image part #################################
    load_image: bool = False  # whether to load images
    image_type: str = "rgb"  # rgb, normal, rgb_or_normal
    image_file_type: str = "png"  # png, jpeg
    image_type_ratio: float = (
        1.0  # ratio of rgb for each dataset when image_type is "rgb_or_normal"
    )
    crop_image: bool = True  # whether to crop the input image
    random_color_jitter: bool = (
        False  # whether to randomly color jitter the input images
    )
    random_rotate: bool = (
        False  # whether to randomly rotate the input images, default [-10 deg, 10 deg]
    )
    random_mask: bool = False  # whether to add random mask to the input image
    background_color: Tuple[int, int, int] = field(
        default_factory=lambda: (255, 255, 255)
    )
    idx: Optional[List[int]] = None  # index of the image to load
    n_views: int = 1  # number of views
    foreground_ratio: Optional[float] = 0.90

    ################################# Caption part #################################
    load_caption: bool = False  # whether to load captions
    load_label: bool = False  # whether to load labels
class BaseDataset(Dataset):
    """Base dataset yielding point-cloud geometry, optional SDF/occupancy
    supervision, an optional conditioning image, and optional caption/label.

    The split file ``{root_dir}/{split}.json`` holds a flat list of sample
    uids (relative paths used to locate surfaces/images/metas on disk).
    """

    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: BaseDataModuleConfig = cfg
        self.split = split
        self.uids = json.load(open(f"{cfg.root_dir}/{split}.json"))
        print(f"Loaded {len(self.uids)} {split} uids")

        # add ColorJitter transforms for input images
        if self.cfg.random_color_jitter:
            self.color_jitter = transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
            )

        # add RandomRotation transforms for input images
        if self.cfg.random_rotate:
            self.rotate = transforms.RandomRotation(
                degrees=10, fill=(*self.cfg.background_color, 0.0)
            )  # by default 10 deg

    def __len__(self):
        return len(self.uids)

    def _load_shape_from_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
        """Load and subsample the input surface point cloud for one sample.

        Returns a dict with ``uid``, ``surface`` (N x (xyz + normal), float32)
        and, when ``with_sharp_data`` is set, ``sharp_surface``.
        """
        if self.cfg.geo_data_type == "sdf":
            data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz")
            # for input point cloud
            surface = data["surface"]
            if self.cfg.with_sharp_data:
                sharp_surface = data["sharp_surface"]
        else:
            raise NotImplementedError(
                f"Data type {self.cfg.geo_data_type} not implemented"
            )

        # random sampling
        if self.cfg.sampling_strategy == "random":
            rng = np.random.default_rng()
            ind = rng.choice(
                surface.shape[0],
                self.cfg.upsample_ratio * self.cfg.n_samples,
                replace=True,
            )
            surface = surface[ind]
            if self.cfg.with_sharp_data:
                # NOTE: sharp points are sampled with the SAME indices as the
                # surface, keeping the two clouds aligned per-row.
                sharp_surface = sharp_surface[ind]
        elif self.cfg.sampling_strategy == "fps":
            import fpsample

            kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
                surface[:, :3], self.cfg.n_samples, h=5
            )
            surface = surface[kdline_fps_samples_idx]
            if self.cfg.with_sharp_data:
                kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
                    sharp_surface[:, :3], self.cfg.n_samples, h=5
                )
                sharp_surface = sharp_surface[kdline_fps_samples_idx]
        else:
            raise NotImplementedError(
                f"sampling strategy {self.cfg.sampling_strategy} not implemented"
            )

        # rescale data (positions only; normals in columns 3: stay unit)
        surface[:, :3] = surface[:, :3] * self.cfg.scale  # target scale
        if self.cfg.with_sharp_data:
            sharp_surface[:, :3] = sharp_surface[:, :3] * self.cfg.scale  # target scale
            ret = {
                "uid": self.uids[index].split("/")[-1],
                "surface": surface.astype(np.float32),
                "sharp_surface": sharp_surface.astype(np.float32),
            }
        else:
            ret = {
                "uid": self.uids[index].split("/")[-1],
                "surface": surface.astype(np.float32),
            }

        return ret

    def _load_shape_supervision_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
        """Load supervision points and their sdf/occupancy/tsdf targets.

        Returns ``rand_points`` plus ``sdf`` or ``occupancies`` depending on
        ``cfg.supervision_type``.
        """
        # for supervision
        ret = {}
        if self.cfg.geo_data_type == "sdf":
            data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz")
            # volume-random and near-surface points share the (xyz, sdf) layout
            data = np.concatenate(
                [data["volume_rand_points"], data["near_surface_points"]], axis=0
            )
            rand_points, sdfs = data[:, :3], data[:, 3:]
        else:
            raise NotImplementedError(
                f"Data type {self.cfg.geo_data_type} not implemented"
            )

        # random sampling (without replacement: supervision points are unique)
        rng = np.random.default_rng()
        ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False)
        rand_points = rand_points[ind]
        rand_points = rand_points * self.cfg.scale
        ret["rand_points"] = rand_points.astype(np.float32)

        if self.cfg.geo_data_type == "sdf":
            if self.cfg.supervision_type == "sdf":
                ret["sdf"] = sdfs[ind].flatten().astype(np.float32)
            elif self.cfg.supervision_type == "occupancy":
                # inside (sdf < 1e-3) -> 0, outside -> 1
                ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(
                    np.float32
                )
            elif self.cfg.supervision_type == "tsdf":
                # truncate to +-tsdf_threshold, then normalize to [-1, 1]
                ret["sdf"] = (
                    sdfs[ind]
                    .flatten()
                    .astype(np.float32)
                    .clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold)
                    / self.cfg.tsdf_threshold
                )
            else:
                raise NotImplementedError(
                    f"Supervision type {self.cfg.supervision_type} not implemented"
                )

        return ret

    def _load_image(self, index: int) -> Dict[str, Any]:
        """Load one conditioning image (rgb or normal) with optional
        augmentations; returns ``sel_image_idx``, ``image`` (H x W x 3 in
        [0, 1]) and ``mask`` (1 x H x W in [0, 1])."""

        def _process_img(image, background_color=(255, 255, 255), foreground_ratio=0.9):
            # Composite onto the background, crop to the alpha bbox, shrink the
            # foreground by `foreground_ratio`, pad to 1:1 and resize to 512.
            alpha = image.getchannel("A")
            background = Image.new("RGBA", image.size, (*background_color, 255))
            image = Image.alpha_composite(background, image)
            image = image.crop(alpha.getbbox())
            new_size = tuple(int(dim * foreground_ratio) for dim in image.size)
            resized_image = image.resize(new_size)
            padded_image = Image.new("RGBA", image.size, (*background_color, 255))
            paste_position = (
                (image.width - resized_image.width) // 2,
                (image.height - resized_image.height) // 2,
            )
            padded_image.paste(resized_image, paste_position)

            # Expand image to 1:1
            max_dim = max(padded_image.size)
            image = Image.new("RGBA", (max_dim, max_dim), (*background_color, 255))
            paste_position = (
                (max_dim - padded_image.width) // 2,
                (max_dim - padded_image.height) // 2,
            )
            image.paste(padded_image, paste_position)
            image = image.resize((512, 512))
            return image.convert("RGB"), alpha

        ret = {}
        if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal":
            assert (
                self.cfg.n_views == 1
            ), "Only single view is supported for single image"
            sel_idx = random.choice(self.cfg.idx)
            ret["sel_image_idx"] = sel_idx
            if self.cfg.image_type == "rgb":
                img_path = (
                    f"{self.cfg.root_dir}/images/"
                    + "/".join(self.uids[index].split("/")[-2:])
                    + f"/{'{:04d}'.format(sel_idx)}_rgb.{self.cfg.image_file_type}"
                )
            elif self.cfg.image_type == "normal":
                img_path = (
                    f"{self.cfg.root_dir}/images/"
                    + "/".join(self.uids[index].split("/")[-2:])
                    + f"/{'{:04d}'.format(sel_idx)}_normal.{self.cfg.image_file_type}"
                )
            image = Image.open(img_path).copy()

            # add random color jitter
            if self.cfg.random_color_jitter:
                rgb = self.color_jitter(image.convert("RGB"))
                image = Image.merge("RGBA", (*rgb.split(), image.getchannel("A")))

            # add random rotation
            if self.cfg.random_rotate:
                image = self.rotate(image)

            # BUGFIX: resolve the background color once, BEFORE branching on
            # crop_image.  The original defined it only inside the crop branch,
            # so the no-crop path raised NameError; it was also a torch tensor,
            # while PIL expects a plain tuple of ints.
            if self.cfg.background_color is None:
                background_color = tuple(int(c) for c in torch.randint(0, 256, (3,)))
            else:
                background_color = tuple(self.cfg.background_color)

            # add crop
            if self.cfg.crop_image:
                image, alpha = _process_img(
                    image, background_color, self.cfg.foreground_ratio
                )
            else:
                alpha = image.getchannel("A")
                background = Image.new("RGBA", image.size, (*background_color, 255))
                image = Image.alpha_composite(background, image).convert("RGB")
            ret["image"] = torch.from_numpy(np.array(image) / 255.0)
            ret["mask"] = torch.from_numpy(np.array(alpha) / 255.0).unsqueeze(0)
        else:
            raise NotImplementedError(
                f"Image type {self.cfg.image_type} not implemented"
            )

        return ret

    def _get_data(self, index):
        """Assemble the full sample dict for `index`: geometry, optional
        supervision, optional image, optional caption/label."""
        ret = {"uid": self.uids[index]}

        # random flip (shared decision so geometry and image stay consistent)
        flip = np.random.rand() < 0.5 if self.cfg.random_flip else False

        # load geometry
        if self.cfg.load_geometry:
            if self.cfg.geo_data_type == "occupancy" or self.cfg.geo_data_type == "sdf":
                # load shape
                ret = self._load_shape_from_occupancy_or_sdf(index)
                # load supervision for shape
                if self.cfg.load_geometry_supervision:
                    ret.update(self._load_shape_supervision_occupancy_or_sdf(index))
            else:
                raise NotImplementedError(
                    f"Geo data type {self.cfg.geo_data_type} not implemented"
                )

            if flip:  # random flip the input point cloud and the supervision
                for key in ret.keys():
                    if key in ["surface", "sharp_surface"]:  # N x (xyz + normal)
                        # mirror position x and normal x
                        ret[key][:, 0] = -ret[key][:, 0]
                        ret[key][:, 3] = -ret[key][:, 3]
                    elif key in ["rand_points"]:
                        ret[key][:, 0] = -ret[key][:, 0]

        # load image
        if self.cfg.load_image:
            ret.update(self._load_image(index))
            if flip:  # random flip the input image
                for key in ret.keys():
                    if key in ["image"]:  # random flip the input image
                        ret[key] = torch.flip(ret[key], [2])
                    if key in ["mask"]:  # random flip the input image
                        ret[key] = torch.flip(ret[key], [2])

        # load caption
        meta = None
        if self.cfg.load_caption:
            with open(f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r") as f:
                meta = json.load(f)
            ret.update({"caption": meta["caption"]})

        # load label (reuse meta if the caption branch already read it)
        if self.cfg.load_label:
            if meta is None:
                with open(
                    f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r"
                ) as f:
                    meta = json.load(f)
            ret.update({"label": [meta["label"]]})

        return ret

    def __getitem__(self, index):
        # Best-effort loading: on any per-sample failure, log and retry with a
        # random other index so a corrupt file does not kill the epoch.
        try:
            return self._get_data(index)
        except Exception as e:
            print(f"Error in {self.uids[index]}: {e}")
            return self.__getitem__(np.random.randint(len(self)))

    def collate(self, batch):
        """Collate a list of sample dicts into a batch (default stacking)."""
        # removed an unused `default_collate_fn_map` import from the original
        return torch.utils.data.default_collate(batch)