Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import json | |
| import numpy as np | |
| import webdataset as wds | |
| import pytorch_lightning as pl | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torch.utils.data.distributed import DistributedSampler | |
| from PIL import Image | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from src.utils.train_util import instantiate_from_config | |
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    Each of ``train``, ``validation`` and ``test`` is an optional config that
    is passed to ``instantiate_from_config`` during :meth:`setup`. Splits
    whose config is ``None`` are not registered.

    Args:
        batch_size: Batch size for the train and test loaders.
        num_workers: Worker count passed to every loader.
        train: Config for the training dataset, or ``None``.
        validation: Config for the validation dataset, or ``None``.
        test: Config for the test dataset, or ``None``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Keep only the splits that were actually configured.
        candidates = {'train': train, 'validation': validation, 'test': test}
        self.dataset_configs = {
            name: cfg for name, cfg in candidates.items() if cfg is not None
        }

    def setup(self, stage):
        """Instantiate every configured dataset.

        Only the ``'fit'`` stage is supported.

        Raises:
            NotImplementedError: For any stage other than ``'fit'``.
        """
        if stage not in ['fit']:
            raise NotImplementedError
        self.datasets = {
            name: instantiate_from_config(cfg)
            for name, cfg in self.dataset_configs.items()
        }

    def train_dataloader(self):
        """Return a shuffled loader over the ``'train'`` dataset."""
        # No DistributedSampler is attached here.
        return wds.WebLoader(
            self.datasets['train'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the ``'validation'`` dataset.

        The batch size is fixed at 4 and does not follow ``self.batch_size``.
        """
        return wds.WebLoader(
            self.datasets['validation'],
            batch_size=4,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        """Return an unshuffled loader for testing.

        Uses the ``'test'`` dataset when one was configured; otherwise falls
        back to the ``'validation'`` dataset so that configurations without a
        test split keep working.
        """
        split = 'test' if 'test' in self.datasets else 'validation'
        return wds.WebLoader(
            self.datasets[split],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
class RefinementData(Dataset):
    """Dataset of paired latent tensors: one target and one prediction per object.

    Every object lives in ``root_dir/<uuid>/`` and is read from
    ``<gt_subpath>/<light_index>/latent.pt`` (target) and
    ``<pred_subpath>/latent.pt`` (prediction). The uuids are taken from the
    ``'val'`` or ``'train'`` list of the split JSON file.

    Args:
        root_dir: Dataset root directory.
        gt_subpath: Per-object sub-directory holding the target latents.
        pred_subpath: Per-object sub-directory holding the predicted latent.
        validation: Use the ``'val'`` split instead of ``'train'``.
        overfit: Accepted for config compatibility; not used.
        caption_path: Optional JSON file (relative to ``root_dir``) mapping
            uuid to caption. When omitted, every sample gets an empty caption.
        split_path: JSON file (relative to ``root_dir``) with ``'train'`` and
            ``'val'`` uuid lists. Required.
        single_view: Crop one randomly chosen view tile out of each latent.
        single_light: Always use light index 0 instead of a random one.

    Raises:
        TypeError: If ``split_path`` is not provided.
    """

    lights_to_caption = {
        0: "Morning light",
        1: "Midday light, clear sky",
        2: "Afternoon light, cloudy sky",
    }

    def __init__(self,
                 root_dir='refinement_dataset/',
                 gt_subpath='gt',
                 pred_subpath='shap_e',
                 validation=False,
                 overfit=False,
                 caption_path=None,
                 split_path=None,
                 single_view=False,
                 single_light=False,
                 ) -> None:
        if split_path is None:
            # Previously this surfaced as an opaque "Path / None" TypeError.
            raise TypeError(
                'RefinementData requires split_path (a JSON file relative to root_dir)'
            )

        self.root_dir = Path(root_dir)
        self.gt_subpath = gt_subpath
        self.pred_subpath = pred_subpath
        self.single_view = single_view
        self.single_light = single_light

        # None means "no caption file configured"; __getitem__ then returns ''.
        self.captions_dict = None
        if caption_path is not None:
            with open(self.root_dir / caption_path) as f:
                self.captions_dict = json.load(f)

        with open(self.root_dir / split_path) as f:
            split_dict = json.load(f)

        uuids = split_dict['val'] if validation else split_dict['train']
        self.paths = [self.root_dir / uuid for uuid in uuids]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        """Return the number of objects in the selected split."""
        return len(self.paths)

    def load_im(self, path, color):
        """Load an image, composite it over ``color`` and return ``(image, alpha)``.

        A 4-channel image is blended onto ``color`` using its last channel as
        alpha; any other image is returned unchanged with an all-ones alpha.
        Both results are float tensors in channel-first layout.
        """
        # Context manager so the underlying file handle is released promptly.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.shape[2] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        """Return the latent pair, caption and index for object ``index``.

        Returns:
            dict with keys ``'refined_imgs'`` (target latent),
            ``'unrefined_imgs'`` (predicted latent), ``'caption'`` and
            ``'index'``.

        Raises:
            KeyError: If a caption file was given but lacks this uuid.
            Exception: Whatever ``torch.load`` raises is re-raised after the
                offending paths are printed.
        """
        sample_dir = self.paths[index]

        if self.single_view:
            view_index = np.random.randint(0, 6)

        if self.single_light:
            light_index = 0
        else:
            # Objects carrying a lights.json are treated as having 3 lights;
            # otherwise each entry in the target directory counts as one.
            if os.path.exists(sample_dir / 'lights.json'):
                num_lights = 3
            else:
                num_lights = len(os.listdir(sample_dir / self.gt_subpath))
            light_index = np.random.randint(0, num_lights)

        uuid = sample_dir.name
        caption = '' if self.captions_dict is None else self.captions_dict[uuid]

        # Honour the configured sub-directories rather than hard-coded names.
        image_path_gt = os.path.join(sample_dir, self.gt_subpath, str(light_index), "latent.pt")
        image_path_pred = os.path.join(sample_dir, self.pred_subpath, "latent.pt")

        # NOTE(review): torch.load unpickles its input; only point this
        # dataset at trusted files.
        try:
            imgs_gt = torch.load(image_path_gt, map_location='cpu').squeeze()
            imgs_pred = torch.load(image_path_pred, map_location='cpu').squeeze()
        except Exception:
            print("Error loading latent tensors, gt path %s, pred path %s" % (image_path_gt, image_path_pred))
            raise

        if self.single_view:
            # The last two dims are indexed as a grid of 40x40 tiles, two per
            # row -- presumably 3 rows x 2 columns of views; TODO confirm.
            row = view_index // 2
            col = view_index % 2
            imgs_gt = imgs_gt[:, row*40:(row+1)*40, col*40:(col+1)*40]
            imgs_pred = imgs_pred[:, row*40:(row+1)*40, col*40:(col+1)*40]

        data = {
            'refined_imgs': imgs_gt,      # target latent (squeezed)
            'unrefined_imgs': imgs_pred,  # predicted latent (squeezed)
            'caption': caption,
            'index': index
        }
        return data
class ObjaverseData(Dataset):
    """Dataset yielding one conditioning image and six target images per object.

    Object paths come from a JSON metadata file whose values are lists of
    paths; the lists are concatenated in file order. The last 16 paths form
    the validation split and everything before them the training split.

    Args:
        root_dir: Dataset root directory.
        meta_fname: Metadata JSON file name, relative to ``root_dir``.
        image_dir: Sub-directory of ``root_dir`` that holds the renderings.
        validation: Select the validation split instead of the training one.
    """

    def __init__(self,
                 root_dir='objaverse/',
                 meta_fname='valid_paths.json',
                 image_dir='rendering_zero123plus',
                 validation=False,
                 ):
        self.root_dir = Path(root_dir)
        self.image_dir = image_dir

        with open(os.path.join(root_dir, meta_fname)) as f:
            lvis_dict = json.load(f)

        paths = []
        for group in lvis_dict.values():
            paths.extend(group)

        # The last 16 objects are held out for validation.
        self.paths = paths[-16:] if validation else paths[:-16]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        """Return the number of objects in the selected split."""
        return len(self.paths)

    @staticmethod
    def _composite_on_background(image, color):
        """Blend an (H, W, C) float array over ``color`` and return ``(image, alpha)``.

        A 4-channel array uses its last channel as alpha. Any other array is
        returned unchanged with an all-ones alpha, so images without an alpha
        channel load instead of failing on a shape mismatch.
        """
        if image.shape[2] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])
        return image, alpha

    def load_im(self, path, color):
        """Load an image, composite it over ``color`` and return ``(image, alpha)``.

        Both results are float tensors in channel-first layout.
        """
        # Context manager so the underlying file handle is released promptly.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.
        image, alpha = self._composite_on_background(image, color)
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        """Return the images of object ``index``.

        Loads ``000.png`` .. ``006.png`` from the object's directory. If any
        of them fails to load, the error is printed and a randomly chosen
        object is tried instead; this repeats until one object loads fully.

        Returns:
            dict with ``'cond_imgs'`` (first image) and ``'target_imgs'``
            (the remaining six images, stacked).
        """
        # Background colour used for compositing: white.
        bkg_color = [1., 1., 1.]

        while True:
            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])
            img_list = []
            try:
                for idx in range(7):
                    img, _ = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)
                    img_list.append(img)
            except Exception as e:
                # Deliberate best-effort: skip broken objects rather than crash.
                print(e)
                index = np.random.randint(0, len(self.paths))
                continue
            break

        imgs = torch.stack(img_list, dim=0).float()
        data = {
            'cond_imgs': imgs[0],      # (3, H, W)
            'target_imgs': imgs[1:],   # (6, 3, H, W)
        }
        return data