| """ | |
| Preprocessing Script for ScanNet 20/200 | |
| Author: Xiaoyang Wu ([email protected]) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
import warnings
import torch

warnings.filterwarnings("ignore", category=DeprecationWarning)
import sys
import os
import argparse
import glob
import json
import plyfile
import numpy as np
import pandas as pd
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

# Load external constants
from meta_data.scannet200_constants import VALID_CLASS_IDS_200, VALID_CLASS_IDS_20

CLOUD_FILE_PFIX = '_vh_clean_2'
SEGMENTS_FILE_PFIX = '.0.010000.segs.json'
AGGREGATIONS_FILE_PFIX = '.aggregation.json'
CLASS_IDS200 = VALID_CLASS_IDS_200
CLASS_IDS20 = VALID_CLASS_IDS_20
IGNORE_INDEX = -1
def read_plymesh(filepath):
    """Read a ply file and return (vertices, faces) as numpy arrays. Returns None if empty."""
    with open(filepath, 'rb') as f:
        plydata = plyfile.PlyData.read(f)
    if plydata.elements:
        vertices = pd.DataFrame(plydata['vertex'].data).values
        faces = np.stack(plydata['face'].data['vertex_indices'], axis=0)
        return vertices, faces
# Map the raw category id to the point cloud
def point_indices_from_group(seg_indices, group, labels_pd):
    group_segments = np.array(group['segments'])
    label = group['label']

    # Map the category name to id
    label_id20 = labels_pd[labels_pd['raw_category'] == label]['nyu40id']
    label_id20 = int(label_id20.iloc[0]) if len(label_id20) > 0 else 0
    label_id200 = labels_pd[labels_pd['raw_category'] == label]['id']
    label_id200 = int(label_id200.iloc[0]) if len(label_id200) > 0 else 0

    # Only store the valid categories; everything else maps to IGNORE_INDEX
    if label_id20 in CLASS_IDS20:
        label_id20 = CLASS_IDS20.index(label_id20)
    else:
        label_id20 = IGNORE_INDEX
    if label_id200 in CLASS_IDS200:
        label_id200 = CLASS_IDS200.index(label_id200)
    else:
        label_id200 = IGNORE_INDEX

    # Get points whose segment indices (points labelled with segment ids) are in the group segment list
    point_idx = np.where(np.isin(seg_indices, group_segments))[0]
    return point_idx, label_id20, label_id200
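
# Worked example (illustrative values, not part of the pipeline): with
# seg_indices = [12, 12, 7, 12, 9] and group['segments'] = [12, 9],
# np.isin marks positions [True, True, False, True, True], so
# point_idx = [0, 1, 3, 4] and those four points receive the group's labels.
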
def face_normal(vertex, face):
    """Compute per-face unit normals and triangle areas from the mesh connectivity."""
    v01 = vertex[face[:, 1]] - vertex[face[:, 0]]
    v02 = vertex[face[:, 2]] - vertex[face[:, 0]]
    vec = np.cross(v01, v02)
    length = np.sqrt(np.sum(vec ** 2, axis=1, keepdims=True)) + 1.0e-8
    nf = vec / length
    area = length * 0.5
    return nf, area


def vertex_normal(vertex, face):
    """Area-weighted vertex normals: accumulate face normals onto incident vertices, then renormalize."""
    nf, area = face_normal(vertex, face)
    nf = nf * area
    nv = np.zeros_like(vertex)
    for i in range(face.shape[0]):
        nv[face[i]] += nf[i]
    length = np.sqrt(np.sum(nv ** 2, axis=1, keepdims=True)) + 1.0e-8
    nv = nv / length
    return nv
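
# Quick sanity check (illustrative only, kept as a comment so the script's
# behaviour is unchanged): a single triangle lying in the xy-plane,
#   vertex_normal(np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]]),
#                 np.array([[0, 1, 2]]))
# yields a unit normal of approximately (0, 0, 1) at every vertex.
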
def handle_process(scene_path, output_path, labels_pd, train_scenes, val_scenes, parse_normals=True):
    scene_id = os.path.basename(scene_path)
    mesh_path = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}.ply')
    segments_file = os.path.join(scene_path, f'{scene_id}{CLOUD_FILE_PFIX}{SEGMENTS_FILE_PFIX}')
    aggregations_file = os.path.join(scene_path, f'{scene_id}{AGGREGATIONS_FILE_PFIX}')
    info_file = os.path.join(scene_path, f'{scene_id}.txt')

    if scene_id in train_scenes:
        output_file = os.path.join(output_path, 'train', f'{scene_id}.pth')
        split_name = 'train'
    elif scene_id in val_scenes:
        output_file = os.path.join(output_path, 'val', f'{scene_id}.pth')
        split_name = 'val'
    else:
        output_file = os.path.join(output_path, 'test', f'{scene_id}.pth')
        split_name = 'test'

    print(f'Processing: {scene_id} in {split_name}')

    vertices, faces = read_plymesh(mesh_path)
    # _vh_clean_2 vertices are expected as x, y, z, r, g, b (plus alpha); keep xyz and rgb
    coords = vertices[:, :3]
    colors = vertices[:, 3:6]
    save_dict = dict(coord=coords, color=colors, scene_id=scene_id)

    # # Rotating the mesh to axis aligned
    # info_dict = {}
    # with open(info_file) as f:
    #     for line in f:
    #         (key, val) = line.split(" = ")
    #         info_dict[key] = np.fromstring(val, sep=' ')
    #
    # if 'axisAlignment' not in info_dict:
    #     rot_matrix = np.identity(4)
    # else:
    #     rot_matrix = info_dict['axisAlignment'].reshape(4, 4)
    # r_coords = coords.transpose()
    # r_coords = np.append(r_coords, np.ones((1, r_coords.shape[1])), axis=0)
    # r_coords = np.dot(rot_matrix, r_coords)
    # coords = r_coords

    # Parse normals
    if parse_normals:
        save_dict["normal"] = vertex_normal(coords, faces)

    # Load segments file (test scenes have no ground-truth labels)
    if split_name != "test":
        with open(segments_file) as f:
            segments = json.load(f)
            seg_indices = np.array(segments['segIndices'])

        # Load aggregations file
        with open(aggregations_file) as f:
            aggregation = json.load(f)
            seg_groups = np.array(aggregation['segGroups'])

        # Generate new labels
        semantic_gt20 = np.ones((vertices.shape[0])) * IGNORE_INDEX
        semantic_gt200 = np.ones((vertices.shape[0])) * IGNORE_INDEX
        instance_ids = np.ones((vertices.shape[0])) * IGNORE_INDEX
        for group in seg_groups:
            point_idx, label_id20, label_id200 = \
                point_indices_from_group(seg_indices, group, labels_pd)
            semantic_gt20[point_idx] = label_id20
            semantic_gt200[point_idx] = label_id200
            instance_ids[point_idx] = group['id']

        semantic_gt20 = semantic_gt20.astype(int)
        semantic_gt200 = semantic_gt200.astype(int)
        instance_ids = instance_ids.astype(int)
        save_dict["semantic_gt20"] = semantic_gt20
        save_dict["semantic_gt200"] = semantic_gt200
        save_dict["instance_gt"] = instance_ids

        # Stack labels with instance ids and sanity-check them
        processed_vertices = np.hstack((semantic_gt200, instance_ids))
        if np.any(np.isnan(processed_vertices)) or not np.all(np.isfinite(processed_vertices)):
            raise ValueError(f'Found NaN in scene: {scene_id}')

    # Save processed data
    torch.save(save_dict, output_file)
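
# For reference (field names taken from save_dict above; the path is a placeholder):
# a processed scene can be inspected with torch.load, e.g.
#   data = torch.load('<output_root>/train/scene0000_00.pth')
#   data['coord']           # (N, 3) xyz coordinates
#   data['color']           # (N, 3) rgb colors
#   data['normal']          # (N, 3) vertex normals, present when parse_normals is True
#   data['semantic_gt20']   # (N,) index into VALID_CLASS_IDS_20, or -1 (train/val only)
#   data['semantic_gt200']  # (N,) index into VALID_CLASS_IDS_200, or -1 (train/val only)
#   data['instance_gt']     # (N,) instance id, or -1 (train/val only)
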
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', required=True,
                        help='Path to the ScanNet dataset containing scene folders')
    parser.add_argument('--output_root', required=True,
                        help='Output path where train/val/test folders will be located')
    # Note: argparse's type=bool treats any non-empty string (including "False") as True,
    # so the flag is parsed explicitly instead.
    parser.add_argument('--parse_normals', default=True,
                        type=lambda x: str(x).lower() in ('true', '1', 'yes'),
                        help='Whether to parse point normals')
    config = parser.parse_args()

    # Load label map
    labels_pd = pd.read_csv('scannet-preprocess/meta_data/scannetv2-labels.combined.tsv',
                            sep='\t', header=0)

    # Load train/val splits
    with open('scannet-preprocess/meta_data/scannetv2_train.txt') as train_file:
        train_scenes = train_file.read().splitlines()
    with open('scannet-preprocess/meta_data/scannetv2_val.txt') as val_file:
        val_scenes = val_file.read().splitlines()

    # Create output directories
    train_output_dir = os.path.join(config.output_root, 'train')
    os.makedirs(train_output_dir, exist_ok=True)
    val_output_dir = os.path.join(config.output_root, 'val')
    os.makedirs(val_output_dir, exist_ok=True)
    test_output_dir = os.path.join(config.output_root, 'test')
    os.makedirs(test_output_dir, exist_ok=True)

    # Load scene paths
    scene_paths = sorted(glob.glob(config.dataset_root + '/scans*/scene*'))

    # Preprocess data
    print('Processing scenes...')
    pool = ProcessPoolExecutor(max_workers=mp.cpu_count())
    # pool = ProcessPoolExecutor(max_workers=1)
    _ = list(pool.map(handle_process, scene_paths, repeat(config.output_root), repeat(labels_pd),
                      repeat(train_scenes), repeat(val_scenes), repeat(config.parse_normals)))
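
# Example invocation (script name and paths are placeholders):
#   python preprocess_scannet.py \
#       --dataset_root /path/to/ScanNet \
#       --output_root /path/to/processed_scannet \
#       --parse_normals true
# The 'scans*/scene*' glob above is intended to pick up scene folders from
# both the labelled and the test releases (e.g. scans/ and scans_test/).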