Spaces:
Runtime error
Runtime error
| # | |
| # Copyright (C) 2023, Inria | |
| # GRAPHDECO research group, https://team.inria.fr/graphdeco | |
| # All rights reserved. | |
| # | |
| # This software is free for non-commercial, research and evaluation use | |
| # under the terms of the LICENSE.md file. | |
| # | |
| # For inquiries contact [email protected] | |
| # | |
| import os | |
| import random | |
| import json | |
| from utils.system_utils import searchForMaxIteration | |
| from scene.dataset_readers import sceneLoadTypeCallbacks | |
| from scene.gaussian_model import GaussianModel | |
| from arguments import ModelParams | |
| from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON | |
class Scene:
    """Cameras plus Gaussian model for one scene.

    On construction, the scene description is read from ``args`` (a COLMAP
    folder, a Blender ``transforms_train.json`` folder, or a "VideoNVS"
    target when ``args`` has ``num_frames``). Train/test camera lists are then
    built for every requested resolution scale. Unless ``skip_gaussians`` is
    set, ``gaussians`` is filled either from a saved checkpoint or from the
    scene's initial point cloud.
    """

    gaussians: GaussianModel

    def __init__(
        self,
        args: ModelParams,
        gaussians: GaussianModel,
        load_iteration=None,
        shuffle=True,
        resolution_scales=None,
        skip_gaussians=False,
    ):
        """Load a scene and prepare its cameras and Gaussians.

        :param args: Model parameters. ``model_path`` and ``source_path`` are
            always read; the selected scene loader reads further fields.
        :param gaussians: Model that is filled via ``load_ply`` (when resuming)
            or ``create_from_pcd`` (fresh run).
        :param load_iteration: Checkpoint iteration to resume from. ``-1``
            resolves it with ``searchForMaxIteration`` over
            ``<model_path>/point_cloud`` (presumably the latest saved
            iteration). A falsy value starts a fresh run.
        :param shuffle: Shuffle the train and test camera lists in place
            before building per-scale camera lists.
        :param resolution_scales: Scales for which camera lists are built;
            defaults to ``[1.0]``.
        :param skip_gaussians: If true, do not touch ``gaussians``.
        :raises ValueError: If the scene type cannot be recognised.
        """
        # None sentinel instead of a shared mutable default list.
        if resolution_scales is None:
            resolution_scales = [1.0]

        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(
                    os.path.join(self.model_path, "point_cloud")
                )
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        self.train_cameras = {}
        self.test_cameras = {}

        scene_info = self._read_scene_info(args)

        # Only a fresh run records its inputs; a resumed run already has them.
        if not self.loaded_iter:
            self._export_initial_inputs(scene_info)

        if shuffle:
            # Shuffle once, before the per-scale loop, so every resolution
            # scale sees the same camera order (multi-res consistent shuffling).
            random.shuffle(scene_info.train_cameras)
            random.shuffle(scene_info.test_cameras)

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        for resolution_scale in resolution_scales:
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(
                scene_info.train_cameras, resolution_scale, args
            )
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(
                scene_info.test_cameras, resolution_scale, args
            )

        if skip_gaussians:
            return
        if self.loaded_iter:
            self.gaussians.load_ply(self._point_cloud_path(self.loaded_iter))
        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)

    @staticmethod
    def _read_scene_info(args):
        """Pick a scene loader from what ``args`` describes and return its result.

        Detection order: a ``sparse`` folder (COLMAP), then
        ``transforms_train.json`` (Blender), then an ``args.num_frames``
        attribute (VideoNVS).

        :raises ValueError: If none of the known scene layouts matches.
        """
        if os.path.exists(os.path.join(args.source_path, "sparse")):
            return sceneLoadTypeCallbacks["Colmap"](
                args.source_path, args.images, args.eval
            )
        if os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            return sceneLoadTypeCallbacks["Blender"](
                args.source_path, args.white_background, args.eval
            )
        if hasattr(args, "num_frames"):
            print("using video-nvs target")
            return sceneLoadTypeCallbacks["VideoNVS"](
                args.num_frames,
                args.radius,
                args.elevation,
                args.fov,
                args.reso,
                args.images,
                args.masks,
                args.num_pts,
                args.train,
            )
        # Explicit raise: an ``assert`` would vanish under ``python -O`` and
        # leave the caller with an unbound ``scene_info``.
        raise ValueError("Could not recognize scene type!")

    def _export_initial_inputs(self, scene_info):
        """Copy the input point cloud and dump all cameras into ``model_path``.

        Writes ``input.ply`` (a copy of ``scene_info.ply_path``) and
        ``cameras.json`` (test cameras first, then train cameras). The
        directory ``model_path`` must already exist.
        """
        import shutil

        # copyfile streams the data and refuses to copy a file onto itself,
        # whereas opening the destination with "wb" first would truncate it.
        shutil.copyfile(scene_info.ply_path, os.path.join(self.model_path, "input.ply"))

        camlist = []
        if scene_info.test_cameras:
            camlist.extend(scene_info.test_cameras)
        if scene_info.train_cameras:
            camlist.extend(scene_info.train_cameras)
        json_cams = [camera_to_JSON(idx, cam) for idx, cam in enumerate(camlist)]
        with open(os.path.join(self.model_path, "cameras.json"), "w") as file:
            json.dump(json_cams, file)

    def _point_cloud_path(self, iteration):
        """Return ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``."""
        return os.path.join(
            self.model_path,
            "point_cloud",
            "iteration_{}".format(iteration),
            "point_cloud.ply",
        )

    def save(self, iteration):
        """Save the current Gaussians as the checkpoint for ``iteration``."""
        self.gaussians.save_ply(self._point_cloud_path(iteration))

    def getTrainCameras(self, scale=1.0):
        """Return the training cameras built for ``scale``.

        Raises ``KeyError`` if that scale was not in ``resolution_scales``.
        """
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        """Return the test cameras built for ``scale``.

        Raises ``KeyError`` if that scale was not in ``resolution_scales``.
        """
        return self.test_cameras[scale]