import sys
import json
import copy

import numpy as np
import torch
from torch.amp import autocast
from torch.nn import functional as F
from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

sys.path.append("./extern/dust3r")
from dust3r.inference import inference, load_model
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
def visualize_surfels(
    surfels,
    draw_normals=False,
    normal_scale=20,
    disk_resolution=16,
    disk_alpha=0.5
):
    """
    Visualize surfels as 2D disks oriented by their normals in 3D using matplotlib.

    Args:
        surfels (list of Surfel): Each Surfel has at least:
            - position: (x, y, z)
            - normal: (nx, ny, nz)
            - radius: scalar
            - color: (R, G, B) in [0..255] (optional)
        draw_normals (bool): If True, draws the surfel normals as quiver arrows.
        normal_scale (float): Scale factor for the normal arrows.
        disk_resolution (int): Number of segments used to approximate each disk.
        disk_alpha (float): Alpha (transparency) for the filled disks.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Prepare arrays for optional quiver (if draw_normals=True)
    positions = []
    normals = []

    # Accumulate 3D polygons in a list for Poly3DCollection
    polygons = []
    polygon_colors = []

    for s in surfels:
        # --- Extract surfel data ---
        position = s.position
        normal = s.normal
        radius = s.radius
        if isinstance(position, torch.Tensor):
            x, y, z = position.detach().cpu().numpy()
            nx, ny, nz = normal.detach().cpu().numpy()
            radius = radius.detach().cpu().numpy()
        else:
            x, y, z = position
            nx, ny, nz = normal

        # Convert color from [0..255] to [0..1], or use a default
        if s.color is None:
            color = (0.2, 0.6, 1.0)  # light blue
        else:
            r, g, b = s.color
            color = (r / 255.0, g / 255.0, b / 255.0)

        # --- Build local coordinate axes for the disk ---
        normal = np.array([nx, ny, nz], dtype=float)
        norm_len = np.linalg.norm(normal)
        # Skip degenerate normals to avoid NaNs
        if norm_len < 1e-12:
            continue
        normal /= norm_len

        # Pick an 'up' vector that is not too close to the normal
        # so we can build a tangent plane
        up = np.array([0, 0, 1], dtype=float)
        if abs(normal.dot(up)) > 0.9:
            up = np.array([0, 1, 0], dtype=float)

        # xAxis = normal x up
        xAxis = np.cross(normal, up)
        xAxis /= np.linalg.norm(xAxis)
        # yAxis = normal x xAxis
        yAxis = np.cross(normal, xAxis)
        yAxis /= np.linalg.norm(yAxis)

        # --- Create a circle of 'disk_resolution' segments in local 2D coords ---
        angles = np.linspace(0, 2 * np.pi, disk_resolution, endpoint=False)
        circle_points_3d = []
        for theta in angles:
            # local 2D circle: (r*cos(theta), r*sin(theta))
            px = radius * np.cos(theta)
            py = radius * np.sin(theta)
            # transform to 3D world space: position + px*xAxis + py*yAxis
            world_pt = np.array([x, y, z]) + px * xAxis + py * yAxis
            circle_points_3d.append(world_pt)

        # Poly3DCollection expects each filled polygon as a single Nx3 array.
        circle_points_3d = np.array(circle_points_3d)
        polygons.append(circle_points_3d)
        polygon_colors.append(color)

        # Collect positions and normals for quiver (if used)
        positions.append([x, y, z])
        normals.append(normal)

    # --- Draw the disks as polygons ---
    poly_collection = Poly3DCollection(
        polygons,
        facecolors=polygon_colors,
        edgecolors='k',  # black edge
        linewidths=0.5,
        alpha=disk_alpha
    )
    ax.add_collection3d(poly_collection)

    # --- Optionally draw normal vectors (quiver) ---
    if draw_normals and len(positions) > 0:
        X = [p[0] for p in positions]
        Y = [p[1] for p in positions]
        Z = [p[2] for p in positions]
        Nx = [n[0] for n in normals]
        Ny = [n[1] for n in normals]
        Nz = [n[2] for n in normals]
        # Note: if your scene is large, you may want to increase `length`.
        ax.quiver(
            X, Y, Z,
            Nx, Ny, Nz,
            length=normal_scale,
            color='red',
            normalize=True
        )

    # --- Axis labels, aspect ratio, etc. ---
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    try:
        ax.set_box_aspect((1, 1, 1))
    except AttributeError:
        pass  # older matplotlib versions
    plt.title("Surfels as Disks (Oriented by Normal)")
    plt.show()
def visualize_pointcloud(
    points,
    colors=None,
    title='Point Cloud',
    point_size=1,
    alpha=1.0,
    bg_color=(240 / 255, 223 / 255, 223 / 255)  # background color; defaults to a light pink
):
    """
    Visualize a 3D point cloud with optional per-point RGB or RGBA colors,
    keeping the x, y, and z axes on the same scale.

    Parameters
    ----------
    points : np.ndarray or torch.Tensor
        Array or tensor of shape [N, 3]; each row is a 3D point (x, y, z).
    colors : None, str, or np.ndarray
        - If None, all points are drawn in the default color 'blue'.
        - If a string, all points use that color.
        - If an array, it must have shape [N, 3] or [N, 4], one color per point;
          float values should be in [0, 1].
    title : str, optional
        Figure title. Defaults to 'Point Cloud'.
    point_size : float, optional
        Marker size. Defaults to 1.
    alpha : float, optional
        Global point transparency. Defaults to 1.0.
    bg_color : tuple, optional
        Background color as (r, g, b) with each value in [0, 1].
        Defaults to (240/255, 223/255, 223/255).

    Examples
    --------
    >>> import numpy as np
    >>> pts = np.random.rand(1000, 3)
    >>> cols = np.random.rand(1000, 3)
    >>> visualize_pointcloud(pts, colors=cols, title="Random point cloud", bg_color=(0.2, 0.2, 0.3))
    """
    # Convert torch tensors to NumPy arrays
    if isinstance(points, torch.Tensor):
        points = points.detach().cpu().numpy()
    if isinstance(colors, torch.Tensor):
        colors = colors.detach().cpu().numpy()

    # Flatten point/color maps that come with extra leading dimensions
    if len(points.shape) > 2:
        points = points.reshape(-1, 3)
    if colors is not None and isinstance(colors, np.ndarray) and len(colors.shape) > 2:
        colors = colors.reshape(-1, colors.shape[-1])

    # Validate the point array shape
    if points.shape[1] != 3:
        raise ValueError("`points` array must have shape [N, 3].")

    # Handle the color argument
    if colors is None:
        colors = 'blue'
    elif isinstance(colors, np.ndarray):
        colors = np.asarray(colors)
        if colors.shape[0] != points.shape[0]:
            raise ValueError("Colors array length must match the number of points.")
        if colors.shape[1] not in [3, 4]:
            raise ValueError("Colors array must have shape [N, 3] or [N, 4].")

    # Validate the background color
    if not isinstance(bg_color, tuple) or len(bg_color) != 3:
        raise ValueError("Background color must be a tuple of (r, g, b) with values between 0 and 1.")

    # Extract coordinates
    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]

    # Create the figure with the custom background color
    fig = plt.figure(figsize=(8, 6), facecolor=bg_color)
    ax = fig.add_subplot(111, projection='3d')
    ax.set_facecolor(bg_color)

    # Scatter plot
    ax.scatter(x, y, z, c=colors, s=point_size, alpha=alpha)

    # Equal aspect ratio on all three axes
    max_range = np.array([x.max() - x.min(),
                          y.max() - y.min(),
                          z.max() - z.min()]).max() / 2.0
    mid_x = (x.max() + x.min()) * 0.5
    mid_y = (y.max() + y.min()) * 0.5
    mid_z = (z.max() + z.min()) * 0.5
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    # Hide ticks and labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_zlabel('')
    ax.grid(False)

    # Hide the 3D axis panes so the axes themselves are not drawn
    ax.xaxis.pane.set_visible(False)
    ax.yaxis.pane.set_visible(False)
    ax.zaxis.pane.set_visible(False)

    # Set the title
    ax.set_title(title)
    plt.tight_layout()
    plt.show()
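# Example (sketch): visualize_pointcloud also accepts torch tensors and image-shaped
# [H, W, 3] maps, which it flattens internally. The tensors below are synthetic
# illustration data, not pipeline outputs; this helper is never called automatically.
def _demo_visualize_pointcloud_from_maps():
    h, w = 32, 48
    point_map = torch.rand(h, w, 3)   # [H, W, 3] points, like a per-view DUSt3R point map
    color_map = torch.rand(h, w, 3)   # [H, W, 3] per-pixel RGB in [0, 1]
    visualize_pointcloud(point_map, colors=color_map,
                         title="Synthetic point map", point_size=2)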
class Surfel:
    def __init__(self, position, normal, radius=1.0, color=None):
        """
        position: (x, y, z)
        normal: (nx, ny, nz)
        radius: scalar
        color: (r, g, b) or None
        """
        self.position = position
        self.normal = normal
        self.radius = radius
        self.color = color

    def __repr__(self):
        return (f"Surfel(position={self.position}, "
                f"normal={self.normal}, radius={self.radius}, "
                f"color={self.color})")
class Octree:
    def __init__(self, points, indices=None, bbox=None, max_points=10):
        """
        Build an octree:
          - points: numpy array of all points, shape (N, 3)
          - indices: indices of the points stored in this node
          - bbox: bounding box of this node as (center, half_size), where
            half_size is half the edge length of the node's cube
          - max_points: maximum number of points allowed in a leaf node
        """
        self.points = points
        if indices is None:
            indices = np.arange(points.shape[0])
        self.indices = indices
        # If no bounding box is given, compute a cube that encloses all points
        if bbox is None:
            min_bound = points.min(axis=0)
            max_bound = points.max(axis=0)
            center = (min_bound + max_bound) / 2
            half_size = np.max(max_bound - min_bound) / 2
            bbox = (center, half_size)
        self.center, self.half_size = bbox
        self.children = []  # child nodes
        self.max_points = max_points
        if len(self.indices) > self.max_points:
            self.subdivide()

    def subdivide(self):
        """Split this node into 8 children."""
        cx, cy, cz = self.center
        hs = self.half_size / 2
        # Offsets of the eight octants
        offsets = np.array([[dx, dy, dz] for dx in (-hs, hs)
                                         for dy in (-hs, hs)
                                         for dz in (-hs, hs)])
        for offset in offsets:
            child_center = self.center + offset
            child_indices = []
            # Check which points fall inside the child's bounding box
            for idx in self.indices:
                p = self.points[idx]
                if np.all(np.abs(p - child_center) <= hs):
                    child_indices.append(idx)
            child_indices = np.array(child_indices)
            if len(child_indices) > 0:
                child = Octree(self.points, indices=child_indices, bbox=(child_center, hs), max_points=self.max_points)
                self.children.append(child)
        # After subdividing, internal nodes no longer store point indices directly
        self.indices = None

    def sphere_intersects_node(self, center, r):
        """
        Check whether the sphere with the given center and radius r intersects
        this node's axis-aligned bounding box.
        Algorithm: compute the distance from the sphere center to the box,
        counting only the component outside the box; the sphere intersects if
        that distance is at most r.
        """
        diff = np.abs(center - self.center)
        max_diff = diff - self.half_size
        max_diff = np.maximum(max_diff, 0)
        dist_sq = np.sum(max_diff ** 2)
        return dist_sq <= r * r

    def query_ball_point(self, point, r):
        """
        Return the indices of all points within distance r of the given point.
        """
        results = []
        if not self.sphere_intersects_node(point, r):
            return results
        # A node without children is a leaf
        if len(self.children) == 0:
            if self.indices is not None:
                for idx in self.indices:
                    if np.linalg.norm(self.points[idx] - point) <= r:
                        results.append(idx)
            return results
        else:
            for child in self.children:
                results.extend(child.query_ball_point(point, r))
            return results
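# Example (sketch): build an Octree over random points and run a radius query,
# then compare against a brute-force scan. Point count, query center, and radius
# are arbitrary; this helper is illustration only and never called automatically.
def _demo_octree_query():
    rng = np.random.default_rng(0)
    pts = rng.random((500, 3))            # 500 points inside the unit cube
    tree = Octree(pts, max_points=10)     # leaves hold at most 10 points
    query = np.array([0.5, 0.5, 0.5])
    neighbor_idx = tree.query_ball_point(query, r=0.1)
    # Brute-force version of the same query, for comparison
    brute = [i for i, p in enumerate(pts) if np.linalg.norm(p - query) <= 0.1]
    assert set(neighbor_idx) == set(brute)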
def estimate_normal_from_pointmap(pointmap: torch.Tensor) -> torch.Tensor:
    """
    Estimate surface normals from a 3D point map by computing cross products of
    neighboring points, using PyTorch tensors.

    Parameters
    ----------
    pointmap : torch.Tensor
        A PyTorch tensor of shape [H, W, 3] containing 3D points in camera coordinates.
        Each point is represented as (X, Y, Z). This tensor can be on CPU or GPU.

    Returns
    -------
    torch.Tensor
        A PyTorch tensor of shape [H, W, 3] containing estimated surface normals.
        Each normal is a unit vector (X, Y, Z).
        Points where normals cannot be computed (e.g. boundaries) are zero vectors.
    """
    # pointmap has shape (H, W, 3)
    h, w = pointmap.shape[:2]
    device = pointmap.device  # keep the device (CPU/GPU) consistent
    dtype = pointmap.dtype

    # Initialize the normal map
    normal_map = torch.zeros((h, w, 3), device=device, dtype=dtype)
    for y in range(h):
        for x in range(w):
            # Check that the right and bottom neighbors are within bounds
            if x + 1 >= w or y + 1 >= h:
                continue
            p_center = pointmap[y, x]
            p_right = pointmap[y, x + 1]
            p_down = pointmap[y + 1, x]
            # Compute and normalize the two tangent vectors
            v1 = p_right - p_center
            v2 = p_down - p_center
            v1 = v1 / torch.linalg.norm(v1)
            v2 = v2 / torch.linalg.norm(v2)
            # Cross product in camera coordinates
            n_c = torch.cross(v1, v2, dim=0)
            # n_c *= 1e10
            # Skip near-zero normals
            norm_len = torch.linalg.norm(n_c)
            if norm_len < 1e-8:
                continue
            # Normalize and store
            normal_map[y, x] = n_c / norm_len
    return normal_map
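# Example (sketch): on a synthetic planar point map (z = 0 everywhere), every
# interior normal should come out as approximately (0, 0, +/-1). This is only a
# sanity check for the loop above, assuming a small 8x8 grid of made-up points.
def _demo_estimate_normals_on_plane():
    ys, xs = torch.meshgrid(torch.arange(8.0), torch.arange(8.0), indexing='ij')
    plane = torch.stack([xs, ys, torch.zeros_like(xs)], dim=-1)  # [8, 8, 3], z = 0
    normals = estimate_normal_from_pointmap(plane)
    # Interior normals are unit vectors along z; the last row/column stay zero.
    print(normals[3, 3])  # expected: tensor([0., 0., 1.]) up to sign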
def load_multiple_images(image_names, image_size=512, dtype=torch.float32):
    images = load_images(image_names, size=image_size, force_1024=True, dtype=dtype)
    img_ori = (images[0]['img_ori'].squeeze(0).permute(1, 2, 0) + 1.) / 2.  # just for reference
    return images, img_ori


def load_initial_images(image_name):
    images = load_images([image_name], size=512, force_1024=True)
    img_ori = (images[0]['img_ori'].squeeze(0).permute(1, 2, 0) + 1.) / 2.  # [H, W, 3], range [0, 1]
    if len(images) == 1:
        images = [images[0], copy.deepcopy(images[0])]
        images[1]['idx'] = 1
    return images, img_ori
def merge_surfels(
    new_surfels: list,
    current_timestamp: str,
    existing_surfels: list,
    existing_surfel_to_timestamp: dict,
    position_threshold: float = 0.025,
    normal_threshold: float = 0.7,
    max_points_per_node: int = 10  # maximum number of points per octree leaf
):
    """
    Merge new surfels into the existing surfel list. A new surfel is merged into
    an existing one when their positions are within `position_threshold` and
    their normals are aligned (dot product above `normal_threshold`); otherwise
    it is returned for the caller to append. This version scans all existing
    surfels; the Octree above can be used to accelerate the spatial lookup.

    Args:
        new_surfels (list[Surfel]): New surfels to merge.
        current_timestamp (str): Current timestamp.
        existing_surfels (list[Surfel]): Existing surfels.
        existing_surfel_to_timestamp (dict): Mapping from surfel index to its list of timestamps.
        position_threshold (float): Distance threshold for two surfels to count as spatially close.
        normal_threshold (float): Dot-product threshold for two surfel normals to count as aligned.
        max_points_per_node (int): Leaf capacity when building an octree (not used by the linear scan).

    Returns:
        (list[Surfel], dict):
            - Surfels that could not be matched and should be appended to the existing list.
            - The updated existing_surfel_to_timestamp mapping.
    """
    # Sanity check
    assert len(existing_surfels) == len(existing_surfel_to_timestamp), (
        "existing_surfels and existing_surfel_to_timestamp have mismatched lengths."
    )

    # Positions and normals of the existing surfels
    positions = np.array([s.position for s in existing_surfels])  # shape: (N, 3)
    normals = np.array([s.normal for s in existing_surfels])      # shape: (N, 3)

    # New surfels that did not match any existing surfel
    filtered_surfels = []
    merge_count = 0
    for new_surfel in new_surfels:
        is_merged = False
        for idx in range(len(positions)):
            if np.linalg.norm(positions[idx] - new_surfel.position) < position_threshold:
                if np.dot(normals[idx], new_surfel.normal) > normal_threshold:
                    existing_surfel_to_timestamp[idx].append(current_timestamp)
                    is_merged = True
                    merge_count += 1
                    break
        if not is_merged:
            filtered_surfels.append(new_surfel)

    # Return the unmatched surfels and the updated timestamp mapping
    print(f"merge_count: {merge_count}")
    return filtered_surfels, existing_surfel_to_timestamp
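# Example (sketch): merge one frame's surfels into an existing list. The positions
# and the "frame_0"/"frame_1" timestamps are made up; one new surfel is close and
# aligned enough to be merged, the other is kept and appended by the caller,
# mirroring the main loop below. Never called automatically.
def _demo_merge_surfels():
    existing = [Surfel(np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]), radius=0.1)]
    ts_map = {0: ["frame_0"]}
    new = [
        Surfel(np.array([0.0, 0.0, 0.01]), np.array([0.0, 0.0, 1.0]), radius=0.1),  # merges into index 0
        Surfel(np.array([1.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]), radius=0.1),   # stays separate
    ]
    kept, ts_map = merge_surfels(new, "frame_1", existing, ts_map,
                                 position_threshold=0.025, normal_threshold=0.7)
    existing.extend(kept)                     # only the unmatched surfel is appended
    ts_map[len(existing) - 1] = ["frame_1"]   # and gets its own timestamp entry, as in the main loop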
def pointmap_to_surfels(pointmap: torch.Tensor,
                        focal_lengths: torch.Tensor,
                        depth_map: torch.Tensor,
                        poses: torch.Tensor,  # shape: (4, 4)
                        radius_scale: float = 0.5,
                        depth_threshold: float = 1.0,
                        estimate_normals: bool = True):
    surfels = []
    if len(focal_lengths) == 2:
        focal_lengths = torch.mean(focal_lengths, dim=0)
    H, W = pointmap.shape[:2]

    # 1) Estimate normals
    if estimate_normals:
        normal_map = estimate_normal_from_pointmap(pointmap)
    else:
        normal_map = torch.zeros_like(pointmap)

    depth_remove_count = 0
    for v in range(H - 1):
        for u in range(W - 1):
            # Skip points beyond the depth threshold
            if depth_map[v, u] > depth_threshold:
                depth_remove_count += 1
                continue
            position = pointmap[v, u].detach().cpu().numpy()  # in global coords
            normal = normal_map[v, u].detach().cpu().numpy()  # in global coords
            depth = depth_map[v, u].detach().cpu().numpy()    # in local coords
            # Orient the normal consistently with the viewing direction (camera -> point)
            view_direction = position - poses[0:3, 3].detach().cpu().numpy()
            view_direction = view_direction / np.linalg.norm(view_direction)
            if np.dot(view_direction, normal) < 0:
                normal = -normal
            # Shrink the adjustment at grazing angles so those surfels get larger radii
            adjustment_value = 0.2 + 0.8 * np.abs(np.dot(view_direction, normal))
            radius = (radius_scale * depth / focal_lengths / adjustment_value).detach().cpu().numpy()
            surfels.append(Surfel(position, normal, radius))
    print(f"depth_remove_count: {depth_remove_count}")
    return surfels
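# Example (sketch): drive pointmap_to_surfels with a tiny synthetic input instead
# of DUSt3R outputs. The 8x8 planar point map, constant depth of 0.4, identity
# pose, and focal length of 200 are all made-up values for illustration only.
def _demo_pointmap_to_surfels():
    ys, xs = torch.meshgrid(torch.arange(8.0), torch.arange(8.0), indexing='ij')
    pm = torch.stack([xs, ys, torch.full_like(xs, 0.4)], dim=-1)  # [8, 8, 3] planar point map
    depth = torch.full((8, 8), 0.4)                               # [8, 8] constant depth
    pose = torch.eye(4)                                           # camera at the origin
    surfels = pointmap_to_surfels(pm,
                                  focal_lengths=torch.tensor([200.0]),
                                  depth_map=depth,
                                  poses=pose,
                                  radius_scale=0.5,
                                  depth_threshold=1.0)
    print(len(surfels), surfels[0])  # 7*7 surfels from the interior of the 8x8 grid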
def run_dust3r(input_images,
               dust3r,
               batch_size=1,
               niter=1000,
               lr=0.01,
               schedule='linear',
               clean_pc=False,
               focal_lengths=None,
               poses=None,
               device='cuda',
               background_mask=None,
               use_amp=False  # <<< AMP CHANGE: flag to enable/disable AMP
               ):
    # Wrap inference and alignment in autocast so that forward passes and any
    # internal backward passes run in mixed precision when use_amp is True.
    with autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
        pairs = make_pairs(input_images, scene_graph='complete', prefilter=None, symmetrize=True)
        output = inference(pairs, dust3r, device, batch_size=batch_size)
        mode = GlobalAlignerMode.PointCloudDifferentFocalOptimizer
        scene = global_aligner(output, device=device, mode=mode)
        if focal_lengths is not None:
            scene.preset_focal(focal_lengths)
        if poses is not None:
            scene.preset_pose(poses)
        if mode == GlobalAlignerMode.PointCloudDifferentFocalOptimizer:
            # Depending on how dust3r runs its optimization internally, it may or
            # may not require gradient scaling. If needed, use GradScaler manually.
            loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
        else:
            loss = None

    # Optionally clean up the point cloud after alignment
    if clean_pc:
        scene = scene.clean_pointcloud()
    return scene, loss
if __name__ == "__main__":
    load_image_size = 512
    load_dtype = torch.float16
    device = 'cuda'
    model_path = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
    selected_frame_paths = ["assets/jesus/jesus_0.jpg",
                            "assets/jesus/jesus_1.jpg",
                            "assets/jesus/jesus_2.jpg"]
    # pil_image = Image.open("./assets/radcliffe_camera_bg.png").resize((512, 288))
    # r, g, b, a = pil_image.split()
    # background_mask = a
    # background_mask = (1 - torch.tensor(np.array(background_mask))).unsqueeze(0).repeat(2, 1, 1).bool()

    all_surfels = []
    surfel_to_timestamp = {}

    dust3r = load_model(model_path, device=device)
    dust3r.eval()
    dust3r = dust3r.to(device)
    dust3r = dust3r.half()

    if len(selected_frame_paths) == 1:
        selected_frame_paths = selected_frame_paths * 2
    frame_images, frame_img_ori = load_multiple_images(selected_frame_paths,
                                                       image_size=load_image_size,
                                                       dtype=load_dtype)
    scene, loss = run_dust3r(frame_images, dust3r, device=device, use_amp=True)

    # --- 1) Extract outputs ---
    # pointcloud shape: [N, H, W, 3]
    shrink_factor = 0.15
    pointcloud = torch.stack(scene.get_pts3d())
    # poses shape: [N, 4, 4]
    # optimized_poses = scene.get_im_poses()
    # focal_lengths shape: [N]
    focal_lengths = scene.get_focals()
    # adjustion_transformation_matrix = SpatialConstructor.estimate_pose_alignment(optimized_poses, original_camera_poses)  # optimized_poses -> original_camera_poses matrix
    # adjusted_optimized_poses = adjustion_transformation_matrix @ optimized_poses

    # --- 2) Resize pointcloud ---
    # Permute for resizing -> [N, 3, H, W]
    pointcloud = pointcloud.permute(0, 3, 1, 2)
    # Resize using bilinear interpolation
    pointcloud = F.interpolate(
        pointcloud,
        scale_factor=shrink_factor,
        mode='bilinear'
    )
    # Permute back -> [N, H', W', 3], keeping only the last view
    pointcloud = pointcloud.permute(0, 2, 3, 1)[-1:]
    # transform pointcloud
    # pointcloud = torch.stack([SpatialConstructor.transform_pointmap(pointcloud[i], adjustion_transformation_matrix) for i in range(pointcloud.shape[0])])

    rgbs = scene.imgs
    rgbs = torch.tensor(np.array(rgbs))
    rgbs = rgbs.permute(0, 3, 1, 2)
    rgbs = F.interpolate(rgbs, scale_factor=shrink_factor, mode='bilinear')
    rgbs = rgbs.permute(0, 2, 3, 1)[-1:]
    visualize_pointcloud(pointcloud, rgbs, point_size=4)

    # --- 3) Resize depth map ---
    # depth_map shape: [N, H, W]
    depth_map = torch.stack(scene.get_depthmaps())
    # Add channel dimension -> [N, 1, H, W]
    depth_map = depth_map.unsqueeze(1)
    depth_map = F.interpolate(
        depth_map,
        scale_factor=shrink_factor,
        mode='bilinear'
    )
    poses = scene.get_im_poses()[-1:]
    # Remove channel dimension -> [N, H', W']
    depth_map = depth_map.squeeze(1)[-1:]

    for frame_idx in range(len(pointcloud)):
        # if frame_idx > 1:
        #     break
        # Create surfels for the current frame
        surfels = pointmap_to_surfels(
            pointmap=pointcloud[frame_idx],
            focal_lengths=focal_lengths[frame_idx] * shrink_factor,
            depth_map=depth_map[frame_idx],
            poses=poses[frame_idx],
            estimate_normals=True,
            radius_scale=0.5,
            depth_threshold=0.48
        )
        # Merge with existing surfels if not the first frame
        if frame_idx > 0:
            surfels, surfel_to_timestamp = merge_surfels(
                new_surfels=surfels,
                current_timestamp=frame_idx,
                existing_surfels=all_surfels,
                existing_surfel_to_timestamp=surfel_to_timestamp,
                position_threshold=0.01,
                normal_threshold=0.7
            )
        # Update timestamp mapping
        num_surfels = len(surfels)
        surfel_start_index = len(all_surfels)
        for surfel_index in range(num_surfels):
            # Each newly created surfel gets mapped to this frame index
            # surfel_to_timestamp[surfel_start_index + surfel_index] = [frame_idx]
            surfel_to_timestamp[surfel_start_index + surfel_index] = [2]
        all_surfels.extend(surfels)

    positions = np.array([s.position for s in all_surfels], dtype=np.float32)
    normals = np.array([s.normal for s in all_surfels], dtype=np.float32)
    radii = np.array([s.radius for s in all_surfels], dtype=np.float32)
    colors = np.array([s.color for s in all_surfels], dtype=np.float32)
    visualize_surfels(all_surfels)

    # np.savez("./surfels_added_first2.npz",
    #          positions=positions,
    #          normals=normals,
    #          radii=radii,
    #          colors=colors)
    # with open("surfel_to_timestamp_first2.json", "w") as f:
    #     json.dump(surfel_to_timestamp, f)
    np.savez("./surfels_added_only3.npz",
             positions=positions,
             normals=normals,
             radii=radii,
             colors=colors)
    with open("surfel_to_timestamp_only3.json", "w") as f:
        json.dump(surfel_to_timestamp, f)
    stop = 1