import sys
import json
import copy

import numpy as np
import torch
from torch.amp import autocast
from torch.nn import functional as F
from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

sys.path.append("./extern/dust3r")
from dust3r.inference import inference, load_model
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode


def visualize_surfels(
    surfels,
    draw_normals=False,
    normal_scale=20,
    disk_resolution=16,
    disk_alpha=0.5
):
    """
    Visualize surfels as 2D disks oriented by their normals in 3D using matplotlib.

    Args:
        surfels (list of Surfel): Each Surfel has at least:
            - position: (x, y, z)
            - normal: (nx, ny, nz)
            - radius: scalar
            - color: (R, G, B) in [0..255] (optional)
        draw_normals (bool): If True, draws the surfel normals as quiver arrows.
        normal_scale (float): Scale factor for the normal arrows.
        disk_resolution (int): Number of segments used to approximate each disk.
        disk_alpha (float): Alpha (transparency) of the filled disks.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Positions/normals are accumulated for the optional quiver plot.
    positions = []
    normals = []

    # Accumulate 3D polygons for a single Poly3DCollection.
    polygons = []
    polygon_colors = []

    for s in surfels:
        # --- Extract surfel data ---
        position = s.position
        normal = s.normal
        radius = s.radius
        if isinstance(position, torch.Tensor):
            x, y, z = position.detach().cpu().numpy()
            nx, ny, nz = normal.detach().cpu().numpy()
            radius = radius.detach().cpu().numpy()
        else:
            x, y, z = position
            nx, ny, nz = normal

        # Convert color from [0..255] to [0..1], or use a default.
        if s.color is None:
            color = (0.2, 0.6, 1.0)  # light blue
        else:
            r, g, b = s.color
            color = (r / 255.0, g / 255.0, b / 255.0)

        # --- Build local coordinate axes for the disk ---
        normal = np.array([nx, ny, nz], dtype=float)
        norm_len = np.linalg.norm(normal)
        # Skip degenerate normals to avoid NaNs.
        if norm_len < 1e-12:
            continue
        normal /= norm_len

        # Pick an 'up' vector that is not too close to the normal,
        # so we can build a tangent plane.
        up = np.array([0, 0, 1], dtype=float)
        if abs(normal.dot(up)) > 0.9:
            up = np.array([0, 1, 0], dtype=float)

        # xAxis = normal x up
        xAxis = np.cross(normal, up)
        xAxis /= np.linalg.norm(xAxis)
        # yAxis = normal x xAxis
        yAxis = np.cross(normal, xAxis)
        yAxis /= np.linalg.norm(yAxis)

        # --- Create a circle of 'disk_resolution' segments in local 2D coords ---
        angles = np.linspace(0, 2 * np.pi, disk_resolution, endpoint=False)
        circle_points_3d = []
        for theta in angles:
            # Local 2D circle point (r*cos(theta), r*sin(theta)), transformed
            # to world space as position + px*xAxis + py*yAxis.
            px = radius * np.cos(theta)
            py = radius * np.sin(theta)
            world_pt = np.array([x, y, z]) + px * xAxis + py * yAxis
            circle_points_3d.append(world_pt)

        # Poly3DCollection expects each polygon as a single Nx3 array.
        circle_points_3d = np.array(circle_points_3d)
        polygons.append(circle_points_3d)
        polygon_colors.append(color)

        # Collect positions and normals for the quiver plot (if used).
        positions.append([x, y, z])
        normals.append(normal)

    # --- Draw the disks as polygons ---
    poly_collection = Poly3DCollection(
        polygons,
        facecolors=polygon_colors,
        edgecolors='k',  # black edges
        linewidths=0.5,
        alpha=disk_alpha
    )
    ax.add_collection3d(poly_collection)

    # --- Optionally draw normal vectors (quiver) ---
    if draw_normals and len(positions) > 0:
        X = [p[0] for p in positions]
        Y = [p[1] for p in positions]
        Z = [p[2] for p in positions]
        Nx = [n[0] for n in normals]
        Ny = [n[1] for n in normals]
        Nz = [n[2] for n in normals]
        # Note: if your scene is large, you may want to increase `length`.
        ax.quiver(
            X, Y, Z,
            Nx, Ny, Nz,
            length=normal_scale,
            color='red',
            normalize=True
        )

    # --- Axis labels, aspect ratio, etc. ---
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    try:
        ax.set_box_aspect((1, 1, 1))
    except AttributeError:
        pass  # older matplotlib versions
    plt.title("Surfels as Disks (Oriented by Normal)")
    plt.show()
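
# Illustrative only: a minimal, hypothetical smoke test for visualize_surfels,
# assuming the Surfel class defined later in this file. It builds a few
# randomly oriented surfels and renders them; it is not part of the pipeline.
def _demo_visualize_surfels(num_surfels=50, seed=0):
    rng = np.random.default_rng(seed)
    demo_surfels = []
    for _ in range(num_surfels):
        position = rng.uniform(-1.0, 1.0, size=3)
        normal = rng.normal(size=3)
        normal /= np.linalg.norm(normal)
        color = tuple(rng.integers(0, 256, size=3))  # (R, G, B) in [0..255]
        demo_surfels.append(Surfel(position, normal, radius=0.1, color=color))
    visualize_surfels(demo_surfels, draw_normals=True, normal_scale=0.3)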
def visualize_pointcloud(
    points,
    colors=None,
    title='Point Cloud',
    point_size=1,
    alpha=1.0,
    bg_color=(240/255, 223/255, 223/255)  # background color; default is a light pink
):
    """
    Visualize a 3D point cloud with optional per-point RGB or RGBA colors,
    keeping the x, y, and z axes at equal scale.

    Parameters
    ----------
    points : np.ndarray or torch.Tensor
        Array or tensor of shape [N, 3]; each row is a 3D point (x, y, z).
    colors : None, str, or np.ndarray
        - If None, the default color 'blue' is used.
        - If a string, all points use that color.
        - If an array, it must have shape [N, 3] or [N, 4], giving a color per
          point with float values in [0, 1].
    title : str, optional
        Figure title. Default 'Point Cloud'.
    point_size : float, optional
        Point size. Default 1.
    alpha : float, optional
        Global transparency of the points. Default 1.0.
    bg_color : tuple, optional
        Background color as (r, g, b) with values in [0, 1].
        Default (240/255, 223/255, 223/255), a light pink.

    Examples
    --------
    >>> import numpy as np
    >>> pts = np.random.rand(1000, 3)
    >>> cols = np.random.rand(1000, 3)
    >>> visualize_pointcloud(pts, colors=cols, title="Random point cloud", bg_color=(0.2, 0.2, 0.3))
    """
    # Convert torch tensors to NumPy arrays.
    if isinstance(points, torch.Tensor):
        points = points.detach().cpu().numpy()
    if isinstance(colors, torch.Tensor):
        colors = colors.detach().cpu().numpy()

    # Flatten higher-dimensional point/color arrays to [N, C].
    if len(points.shape) > 2:
        points = points.reshape(-1, 3)
    if colors is not None and isinstance(colors, np.ndarray) and len(colors.shape) > 2:
        colors = colors.reshape(-1, colors.shape[-1])

    # Validate the point array shape.
    if points.shape[1] != 3:
        raise ValueError("`points` array must have shape [N, 3].")

    # Handle the colors argument.
    if colors is None:
        colors = 'blue'
    elif isinstance(colors, np.ndarray):
        colors = np.asarray(colors)
        if colors.shape[0] != points.shape[0]:
            raise ValueError("Colors array length must match the number of points.")
        if colors.shape[1] not in [3, 4]:
            raise ValueError("Colors array must have shape [N, 3] or [N, 4].")

    # Validate the background color.
    if not isinstance(bg_color, tuple) or len(bg_color) != 3:
        raise ValueError("Background color must be a tuple of (r, g, b) with values between 0 and 1.")

    # Extract coordinates.
    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]

    # Create the figure with the requested background color.
    fig = plt.figure(figsize=(8, 6), facecolor=bg_color)
    ax = fig.add_subplot(111, projection='3d')
    ax.set_facecolor(bg_color)

    # Draw the scatter plot.
    ax.scatter(x, y, z, c=colors, s=point_size, alpha=alpha)

    # Enforce equal scaling on all three axes.
    max_range = np.array([x.max() - x.min(),
                          y.max() - y.min(),
                          z.max() - z.min()]).max() / 2.0
    mid_x = (x.max() + x.min()) * 0.5
    mid_y = (y.max() + y.min()) * 0.5
    mid_z = (z.max() + z.min()) * 0.5
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    # Hide ticks and labels.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_zlabel('')
    ax.grid(False)

    # Hide the 3D axis panes so the axes are not drawn.
    ax.xaxis.pane.set_visible(False)
    ax.yaxis.pane.set_visible(False)
    ax.zaxis.pane.set_visible(False)

    ax.set_title(title)
    plt.tight_layout()
    plt.show()
class Surfel:
    def __init__(self, position, normal, radius=1.0, color=None):
        """
        position: (x, y, z)
        normal: (nx, ny, nz)
        radius: scalar
        color: (r, g, b) or None
        """
        self.position = position
        self.normal = normal
        self.radius = radius
        self.color = color

    def __repr__(self):
        return (f"Surfel(position={self.position}, "
                f"normal={self.normal}, radius={self.radius}, "
                f"color={self.color})")


class Octree:
    def __init__(self, points, indices=None, bbox=None, max_points=10):
        """
        Build an octree:
        - points: NumPy array of all points, shape (N, 3)
        - indices: indices of the points stored in this node
        - bbox: bounding box of this node as (center, half_size), where
          half_size is half the edge length of a cube
        - max_points: maximum number of points allowed in a leaf node
        """
        self.points = points
        if indices is None:
            indices = np.arange(points.shape[0])
        self.indices = indices
        # If no bounding box is given, compute a cubic bounding box that
        # encloses all points.
        if bbox is None:
            min_bound = points.min(axis=0)
            max_bound = points.max(axis=0)
            center = (min_bound + max_bound) / 2
            half_size = np.max(max_bound - min_bound) / 2
            bbox = (center, half_size)
        self.center, self.half_size = bbox
        self.children = []  # child nodes
        self.max_points = max_points

        if len(self.indices) > self.max_points:
            self.subdivide()

    def subdivide(self):
        """Split the current node into up to 8 children."""
        hs = self.half_size / 2
        # Center offsets of the eight octants.
        offsets = np.array([[dx, dy, dz]
                            for dx in (-hs, hs)
                            for dy in (-hs, hs)
                            for dz in (-hs, hs)])
        # Assign each point to exactly one octant, based on its position
        # relative to this node's center, so that boundary points are not
        # duplicated across children.
        pts = self.points[self.indices]
        octant = ((pts[:, 0] >= self.center[0]).astype(int) * 4
                  + (pts[:, 1] >= self.center[1]).astype(int) * 2
                  + (pts[:, 2] >= self.center[2]).astype(int))
        for child_id, offset in enumerate(offsets):
            child_indices = self.indices[octant == child_id]
            if len(child_indices) > 0:
                child_center = self.center + offset
                child = Octree(self.points,
                               indices=child_indices,
                               bbox=(child_center, hs),
                               max_points=self.max_points)
                self.children.append(child)
        # After subdividing, interior nodes no longer hold point indices.
        self.indices = None

    def sphere_intersects_node(self, center, r):
        """
        Check whether the sphere with the given center and radius r intersects
        this node's axis-aligned bounding box. The test measures the distance
        from the sphere center to the box (counting only the overshoot beyond
        the box faces); the sphere intersects iff that distance is at most r.
        """
        diff = np.abs(center - self.center)
        max_diff = diff - self.half_size
        max_diff = np.maximum(max_diff, 0)
        dist_sq = np.sum(max_diff ** 2)
        return dist_sq <= r * r

    def query_ball_point(self, point, r):
        """
        Return the indices of all points within distance r of `point`.
        """
        results = []
        if not self.sphere_intersects_node(point, r):
            return results
        # A node without children is a leaf.
        if len(self.children) == 0:
            if self.indices is not None:
                for idx in self.indices:
                    if np.linalg.norm(self.points[idx] - point) <= r:
                        results.append(idx)
            return results
        else:
            for child in self.children:
                results.extend(child.query_ball_point(point, r))
            return results
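
# Illustrative only: a hypothetical sanity check that the octree's radius
# query agrees with a brute-force scan over random points.
def _demo_octree_query(num_points=500, radius=0.2, seed=0):
    rng = np.random.default_rng(seed)
    pts = rng.uniform(0.0, 1.0, size=(num_points, 3))
    tree = Octree(pts, max_points=10)
    query = np.array([0.5, 0.5, 0.5])
    tree_hits = sorted(tree.query_ball_point(query, radius))
    brute_hits = sorted(np.where(np.linalg.norm(pts - query, axis=1) <= radius)[0])
    assert np.array_equal(tree_hits, brute_hits), "octree and brute force disagree"
    return tree_hits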
def estimate_normal_from_pointmap(pointmap: torch.Tensor) -> torch.Tensor:
    """
    Estimate surface normals from a 3D point map by computing cross products
    of vectors to neighboring points, using PyTorch tensors.

    Parameters
    ----------
    pointmap : torch.Tensor
        A tensor of shape [H, W, 3] containing 3D points in camera
        coordinates, each represented as (X, Y, Z). May be on CPU or GPU.

    Returns
    -------
    torch.Tensor
        A tensor of shape [H, W, 3] of estimated surface normals. Each normal
        is a unit vector (X, Y, Z); points where a normal cannot be computed
        (e.g. at the image boundary) are left as zero vectors.
    """
    h, w = pointmap.shape[:2]
    device = pointmap.device  # keep the device (CPU/GPU) consistent
    dtype = pointmap.dtype

    # Initialize the normal map.
    normal_map = torch.zeros((h, w, 3), device=device, dtype=dtype)

    for y in range(h):
        for x in range(w):
            # Skip pixels whose right/down neighbors fall outside the image.
            if x + 1 >= w or y + 1 >= h:
                continue
            p_center = pointmap[y, x]
            p_right = pointmap[y, x + 1]
            p_down = pointmap[y + 1, x]

            # Finite-difference vectors toward the right and down neighbors.
            v1 = p_right - p_center
            v2 = p_down - p_center
            # Skip degenerate neighborhoods to avoid division by zero.
            if torch.linalg.norm(v1) < 1e-12 or torch.linalg.norm(v2) < 1e-12:
                continue
            v1 = v1 / torch.linalg.norm(v1)
            v2 = v2 / torch.linalg.norm(v2)

            # Cross product in camera coordinates.
            n_c = torch.cross(v1, v2, dim=0)

            # Skip near-zero normals (v1 and v2 almost parallel).
            norm_len = torch.linalg.norm(n_c)
            if norm_len < 1e-8:
                continue

            # Normalize and store.
            normal_map[y, x] = n_c / norm_len

    return normal_map


def load_multiple_images(image_names, image_size=512, dtype=torch.float32):
    images = load_images(image_names, size=image_size, force_1024=True, dtype=dtype)
    # Keep the first image around in [0, 1] range, just for reference.
    img_ori = (images[0]['img_ori'].squeeze(0).permute(1, 2, 0) + 1.) / 2.
    return images, img_ori
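
# Optional, illustrative sketch: a vectorized variant of the per-pixel loop in
# estimate_normal_from_pointmap. It computes the same right/down finite
# differences with tensor slicing, which is typically far faster on GPU. The
# function name is ours; nothing in the pipeline below calls it.
def estimate_normal_from_pointmap_vectorized(pointmap: torch.Tensor) -> torch.Tensor:
    normal_map = torch.zeros_like(pointmap)
    v1 = F.normalize(pointmap[:-1, 1:] - pointmap[:-1, :-1], dim=-1)  # right-neighbor differences
    v2 = F.normalize(pointmap[1:, :-1] - pointmap[:-1, :-1], dim=-1)  # down-neighbor differences
    n = torch.cross(v1, v2, dim=-1)
    n_len = torch.linalg.norm(n, dim=-1, keepdim=True)
    # Keep zero vectors where the normal is degenerate, as the loop version does.
    normal_map[:-1, :-1] = torch.where(n_len > 1e-8, n / n_len.clamp_min(1e-8), torch.zeros_like(n))
    return normal_map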
def load_initial_images(image_name):
    images = load_images([image_name], size=512, force_1024=True)
    img_ori = (images[0]['img_ori'].squeeze(0).permute(1, 2, 0) + 1.) / 2.  # [H, W, 3], range [0, 1]

    if len(images) == 1:
        images = [images[0], copy.deepcopy(images[0])]
        images[1]['idx'] = 1

    return images, img_ori


def merge_surfels(
    new_surfels: list,
    current_timestamp: str,
    existing_surfels: list,
    existing_surfel_to_timestamp: dict,
    position_threshold: float = 0.025,
    normal_threshold: float = 0.7,
    max_points_per_node: int = 10  # maximum number of points per octree leaf
):
    """
    Merge new surfels into an existing surfel list, using an octree to
    accelerate the spatial lookup.

    Args:
        new_surfels (list[Surfel]): New surfels to merge.
        current_timestamp (str): The current timestamp.
        existing_surfels (list[Surfel]): Already existing surfels.
        existing_surfel_to_timestamp (dict): Mapping from each surfel index to
            its list of timestamps.
        position_threshold (float): Distance threshold under which two surfels
            are considered spatially close.
        normal_threshold (float): Dot-product threshold above which two surfel
            normals are considered aligned.
        max_points_per_node (int): Maximum number of points allowed per leaf
            node when building the octree.

    Returns:
        (list[Surfel], dict):
            - The surfels that could not be matched and should be appended to
              the existing surfel list.
            - The updated existing_surfel_to_timestamp mapping.
    """
    # Sanity check.
    assert len(existing_surfels) == len(existing_surfel_to_timestamp), (
        "existing_surfels and existing_surfel_to_timestamp have mismatched lengths."
    )

    # Positions and normals of the existing surfels.
    positions = np.array([s.position for s in existing_surfels])  # shape: (N, 3)
    normals = np.array([s.normal for s in existing_surfels])      # shape: (N, 3)

    # Build an octree over the existing surfel positions so each new surfel
    # is only tested against nearby candidates.
    octree = Octree(positions, max_points=max_points_per_node) if len(existing_surfels) > 0 else None

    # Surfels that do not match any existing surfel.
    filtered_surfels = []
    merge_count = 0

    for new_surfel in new_surfels:
        is_merged = False
        if octree is not None:
            candidates = octree.query_ball_point(np.asarray(new_surfel.position), position_threshold)
            for idx in candidates:
                if np.dot(normals[idx], new_surfel.normal) > normal_threshold:
                    existing_surfel_to_timestamp[idx].append(current_timestamp)
                    is_merged = True
                    merge_count += 1
                    break
        if not is_merged:
            filtered_surfels.append(new_surfel)

    # Return the unmatched surfels and the updated timestamp mapping.
    print(f"merge_count: {merge_count}")
    return filtered_surfels, existing_surfel_to_timestamp


def pointmap_to_surfels(pointmap: torch.Tensor,
                        focal_lengths: torch.Tensor,
                        depth_map: torch.Tensor,
                        poses: torch.Tensor,  # shape: (4, 4)
                        radius_scale: float = 0.5,
                        depth_threshold: float = 1.0,
                        estimate_normals: bool = True):
    surfels = []
    # Collapse multiple focal lengths (e.g. fx, fy) into a single scalar.
    if isinstance(focal_lengths, torch.Tensor):
        focal_length = float(focal_lengths.mean())
    else:
        focal_length = float(focal_lengths)
    H, W = pointmap.shape[:2]

    # 1) Estimate normals.
    if estimate_normals:
        normal_map = estimate_normal_from_pointmap(pointmap)
    else:
        normal_map = torch.zeros_like(pointmap)

    depth_remove_count = 0
    camera_center = poses[0:3, 3].detach().cpu().numpy()
    for v in range(H - 1):
        for u in range(W - 1):
            # Drop points beyond the depth threshold.
            if depth_map[v, u] > depth_threshold:
                depth_remove_count += 1
                continue
            position = pointmap[v, u].detach().cpu().numpy()   # global coords
            normal = normal_map[v, u].detach().cpu().numpy()   # global coords
            depth = depth_map[v, u].detach().cpu().numpy()     # local (camera) coords

            # Orient the normal consistently with the viewing direction.
            view_direction = position - camera_center
            view_direction = view_direction / np.linalg.norm(view_direction)
            if np.dot(view_direction, normal) < 0:
                normal = -normal

            # Grow the radius for surfels seen at grazing angles so the disk
            # still covers its pixel footprint.
            adjustment_value = 0.2 + 0.8 * np.abs(np.dot(view_direction, normal))
            radius = radius_scale * depth / focal_length / adjustment_value
            surfels.append(Surfel(position, normal, radius))

    print(f"depth_remove_count: {depth_remove_count}")
    return surfels
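
# Illustrative only: a hypothetical end-to-end check of the merge logic. A new
# surfel at (nearly) the same position with an aligned normal should be merged
# into the existing one, while a distant surfel should survive as a new one.
def _demo_merge_surfels():
    up = np.array([0.0, 0.0, 1.0])
    existing = [Surfel(np.array([0.0, 0.0, 0.0]), up, radius=0.05)]
    timestamps = {0: [0]}
    new = [
        Surfel(np.array([0.0, 0.0, 0.005]), up, radius=0.05),  # near duplicate -> merged
        Surfel(np.array([1.0, 0.0, 0.0]), up, radius=0.05),    # far away -> kept
    ]
    kept, timestamps = merge_surfels(
        new_surfels=new,
        current_timestamp=1,
        existing_surfels=existing,
        existing_surfel_to_timestamp=timestamps,
        position_threshold=0.025,
        normal_threshold=0.7,
    )
    assert len(kept) == 1 and timestamps[0] == [0, 1]
    return kept, timestamps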
def run_dust3r(input_images,
               dust3r,
               batch_size=1,
               niter=1000,
               lr=0.01,
               schedule='linear',
               clean_pc=False,
               focal_lengths=None,
               poses=None,
               device='cuda',
               background_mask=None,
               use_amp=False  # <<< AMP CHANGE: flag to enable/disable AMP
               ):
    # Wrap the entire inference and alignment in autocast so that forward
    # passes and any internal backward passes happen in mixed precision.
    with autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
        pairs = make_pairs(input_images, scene_graph='complete', prefilter=None, symmetrize=True)
        output = inference(pairs, dust3r, device, batch_size=batch_size)

        mode = GlobalAlignerMode.PointCloudDifferentFocalOptimizer
        scene = global_aligner(output, device=device, mode=mode)
        if focal_lengths is not None:
            scene.preset_focal(focal_lengths)
        if poses is not None:
            scene.preset_pose(poses)

        if mode == GlobalAlignerMode.PointCloudDifferentFocalOptimizer:
            # Depending on how dust3r internally does optimization, it may or
            # may not require gradient scaling. If you need it, you can do
            # something more manual with GradScaler.
            loss = scene.compute_global_alignment(init='mst', niter=niter,
                                                  schedule=schedule, lr=lr)
        else:
            loss = None

    # Optionally clean up the point cloud after alignment.
    if clean_pc:
        scene = scene.clean_pointcloud()

    return scene, loss
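
# A minimal sketch of the manual GradScaler pattern alluded to in the comment
# above, for readers who do need gradient scaling under AMP. The model,
# optimizer, loader, and loss_fn names are hypothetical placeholders, not part
# of this pipeline, so the sketch is left commented out:
#
#     scaler = torch.amp.GradScaler('cuda')
#     for batch in loader:
#         optimizer.zero_grad()
#         with autocast(device_type='cuda', dtype=torch.float16):
#             loss = loss_fn(model(batch))
#         scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
#         scaler.step(optimizer)          # unscale gradients, then step
#         scaler.update()                 # adjust the scale factor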
if __name__ == "__main__":
    load_image_size = 512
    load_dtype = torch.float16
    device = 'cuda'
    model_path = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
    selected_frame_paths = ["assets/jesus/jesus_0.jpg",
                            "assets/jesus/jesus_1.jpg",
                            "assets/jesus/jesus_2.jpg"]

    # pil_image = Image.open("./assets/radcliffe_camera_bg.png").resize((512, 288))
    # r, g, b, a = pil_image.split()
    # background_mask = a
    # background_mask = (1 - torch.tensor(np.array(background_mask))).unsqueeze(0).repeat(2, 1, 1).bool()

    all_surfels = []
    surfel_to_timestamp = {}

    dust3r = load_model(model_path, device=device)
    dust3r.eval()
    dust3r = dust3r.to(device)
    dust3r = dust3r.half()

    if len(selected_frame_paths) == 1:
        selected_frame_paths = selected_frame_paths * 2

    frame_images, frame_img_ori = load_multiple_images(selected_frame_paths,
                                                       image_size=load_image_size,
                                                       dtype=load_dtype)
    scene, loss = run_dust3r(frame_images, dust3r, device=device, use_amp=True)

    # --- 1) Extract outputs ---
    # pointcloud shape: [N, H, W, 3]
    shrink_factor = 0.15
    pointcloud = torch.stack(scene.get_pts3d())
    # poses shape: [N, 4, 4]
    # optimized_poses = scene.get_im_poses()
    # focal_lengths shape: [N]
    focal_lengths = scene.get_focals()
    # adjustion_transformation_matrix = SpatialConstructor.estimate_pose_alignment(optimized_poses, original_camera_poses)  # optimized_poses -> original_camera_poses matrix
    # adjusted_optimized_poses = adjustion_transformation_matrix @ optimized_poses

    # --- 2) Resize pointcloud ---
    # Permute for resizing -> [N, 3, H, W]
    pointcloud = pointcloud.permute(0, 3, 1, 2)
    # Resize using bilinear interpolation
    pointcloud = F.interpolate(
        pointcloud,
        scale_factor=shrink_factor,
        mode='bilinear'
    )
    # Permute back -> [N, H', W', 3] and keep only the last frame
    pointcloud = pointcloud.permute(0, 2, 3, 1)[-1:]
    # transform pointcloud
    # pointcloud = torch.stack([SpatialConstructor.transform_pointmap(pointcloud[i], adjustion_transformation_matrix) for i in range(pointcloud.shape[0])])

    rgbs = scene.imgs
    rgbs = torch.tensor(np.array(rgbs))
    rgbs = rgbs.permute(0, 3, 1, 2)
    rgbs = F.interpolate(rgbs, scale_factor=shrink_factor, mode='bilinear')
    rgbs = rgbs.permute(0, 2, 3, 1)[-1:]

    visualize_pointcloud(pointcloud, rgbs, point_size=4)

    # --- 3) Resize depth map ---
    # depth_map shape: [N, H, W]
    depth_map = torch.stack(scene.get_depthmaps())
    # Add channel dimension -> [N, 1, H, W]
    depth_map = depth_map.unsqueeze(1)
    depth_map = F.interpolate(
        depth_map,
        scale_factor=shrink_factor,
        mode='bilinear'
    )
    poses = scene.get_im_poses()[-1:]
    # Remove channel dimension -> [N, H', W'] and keep only the last frame
    depth_map = depth_map.squeeze(1)[-1:]

    for frame_idx in range(len(pointcloud)):
        # Create surfels for the current frame
        surfels = pointmap_to_surfels(
            pointmap=pointcloud[frame_idx],
            focal_lengths=focal_lengths[frame_idx] * shrink_factor,
            depth_map=depth_map[frame_idx],
            poses=poses[frame_idx],
            estimate_normals=True,
            radius_scale=0.5,
            depth_threshold=0.48
        )

        # Merge with existing surfels if this is not the first frame
        if frame_idx > 0:
            surfels, surfel_to_timestamp = merge_surfels(
                new_surfels=surfels,
                current_timestamp=frame_idx,
                existing_surfels=all_surfels,
                existing_surfel_to_timestamp=surfel_to_timestamp,
                position_threshold=0.01,
                normal_threshold=0.7
            )

        # Update the timestamp mapping
        num_surfels = len(surfels)
        surfel_start_index = len(all_surfels)
        for surfel_index in range(num_surfels):
            # Each newly created surfel gets mapped to this frame index
            # surfel_to_timestamp[surfel_start_index + surfel_index] = [frame_idx]
            surfel_to_timestamp[surfel_start_index + surfel_index] = [2]

        all_surfels.extend(surfels)

    positions = np.array([s.position for s in all_surfels], dtype=np.float32)
    normals = np.array([s.normal for s in all_surfels], dtype=np.float32)
    radii = np.array([s.radius for s in all_surfels], dtype=np.float32)
    colors = np.array([s.color for s in all_surfels], dtype=np.float32)

    visualize_surfels(all_surfels)

    # np.savez("./surfels_added_first2.npz",
    #          positions=positions,
    #          normals=normals,
    #          radii=radii,
    #          colors=colors)
    # with open("surfel_to_timestamp_first2.json", "w") as f:
    #     json.dump(surfel_to_timestamp, f)

    np.savez("./surfels_added_only3.npz",
             positions=positions,
             normals=normals,
             radii=radii,
             colors=colors)
    with open("surfel_to_timestamp_only3.json", "w") as f:
        json.dump(surfel_to_timestamp, f)