# vmem / extern / CUT3R / viser_utils.py
# liguang0115 — commit 2df809d:
# "Add initial project structure with core files, configurations, and sample images"
import torch
import os
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib as mpl
import cv2
import numpy as np
import matplotlib.cm as cm
import viser
import viser.transforms as tf
import time
import trimesh
import dataclasses
from scipy.spatial.transform import Rotation
from src.dust3r.viz import (
add_scene_cam,
CAM_COLORS,
OPENGL,
pts3d_to_trimesh,
cat_meshes,
)
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    Containers (dict, list, tuple) are traversed recursively and rebuilt with
    the same container type; every leaf is converted individually.

    Args:
        batch: list, tuple, dict of tensors or other things.
        device: pytorch device, or the string ``"numpy"`` to turn torch
            tensors into NumPy arrays.
        callback: optional function called on ``batch`` and on every nested
            sub-element before that element is converted.
        non_blocking: forwarded to ``torch.Tensor.to`` for every tensor leaf.

    Returns:
        An object with the same container structure as ``batch``. With
        ``device == "numpy"`` torch tensors become ndarrays; otherwise NumPy
        arrays become tensors and all tensors are moved to ``device``. Any
        other leaf (including ``None``) is returned unchanged.
    """
    if callback:
        batch = callback(batch)

    # Forward callback / non_blocking so nested elements are handled exactly
    # like the top-level object.
    if isinstance(batch, dict):
        return {
            k: todevice(v, device, callback=callback, non_blocking=non_blocking)
            for k, v in batch.items()
        }
    if isinstance(batch, (tuple, list)):
        return type(batch)(
            todevice(x, device, callback=callback, non_blocking=non_blocking)
            for x in batch
        )

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
# Alternative spelling of ``todevice``; both names refer to the same function.
to_device = todevice


def to_numpy(x):
    """Return ``x`` with every torch tensor (also inside nested dict / list /
    tuple containers) converted to a NumPy array; see ``todevice``."""
    return todevice(x, device="numpy")
def segment_sky(image):
    """Estimate a boolean sky mask for an image with simple colour heuristics.

    The image is converted with ``cv2.COLOR_BGR2HSV``. Pixels inside a fixed
    HSV range, plus bright low-saturation ("luminous gray") pixels, are
    marked. The mask is cleaned with a 5x5 binary opening. Only connected
    components whose area exceeds half of the largest component's area are
    kept.

    Args:
        image: HxWx3 image as a torch tensor or NumPy array. Floating-point
            input is clipped to [0, 1] and scaled to uint8.
            NOTE(review): channel order is presumably RGB (BGR2HSV then maps
            blue sky into the low-hue range tested below) — confirm.

    Returns:
        torch bool tensor of shape HxW, True where sky was detected.
    """
    from scipy import ndimage

    image = to_numpy(image)
    if np.issubdtype(image.dtype, np.floating):
        image = np.uint8(255 * image.clip(min=0, max=1))

    # Convert to HSV
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Colour range -> mask. cv2.inRange yields 0/255 uint8; compare explicitly
    # rather than reinterpreting the bytes as bool.
    lower_blue = np.array([0, 0, 100])
    upper_blue = np.array([30, 255, 255])
    mask = cv2.inRange(hsv, lower_blue, upper_blue) > 0

    # add luminous gray
    sat, val = hsv[:, :, 1], hsv[:, :, 2]
    mask |= (sat < 10) & (val > 150)
    mask |= (sat < 30) & (val > 180)
    mask |= (sat < 50) & (val > 220)

    # Morphological opening removes small speckles
    kernel = np.ones((5, 5), np.uint8)
    mask2 = ndimage.binary_opening(mask, structure=kernel)

    # Keep only large connected components (stats row 0 is the background).
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        mask2.astype(np.uint8), connectivity=8
    )
    cc_sizes = stats[1:, cv2.CC_STAT_AREA]
    if cc_sizes.size:
        selection = 1 + np.flatnonzero(cc_sizes > cc_sizes.max() / 2)
    else:
        selection = np.empty(0, dtype=labels.dtype)

    # np.isin replaces np.in1d (deprecated in NumPy 2.0) and keeps the shape.
    mask3 = np.isin(labels, selection)
    return torch.from_numpy(mask3)
def convert_scene_output_to_glb(
    outdir,
    imgs,
    pts3d,
    mask,
    focals,
    cams2world,
    cam_size=0.05,
    show_cam=True,
    cam_color=None,
    as_pointcloud=False,
    transparent_cams=False,
    silent=False,
    save_name=None,
):
    """Export reconstructed points (and optionally cameras) as a ``.glb`` file.

    Args:
        outdir: directory the file is written to.
        imgs: per-view images, used as point / vertex colours and as camera
            textures; ``imgs[i].shape[1::-1]`` is passed as the image size.
        pts3d: per-view 3D points.
        mask: per-view masks selecting which points are exported.
        focals: per-camera focal lengths.
        cams2world: per-camera 4x4 poses.
        cam_size: screen width passed to ``add_scene_cam``.
        show_cam: whether to add a camera object for every pose.
        cam_color: one colour for all cameras, a list with one colour per
            camera, or None to cycle through ``CAM_COLORS``.
        as_pointcloud: export a coloured point cloud instead of a mesh.
        transparent_cams: draw cameras without the image texture.
        silent: suppress the progress message.
        save_name: file stem, ``"scene"`` when None.

    Returns:
        Path of the written ``.glb`` file.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    mask = to_numpy(mask)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # Geometry: only the first len(pts3d) views have point maps (the assert
    # allows more images than point maps), so pair them with zip.
    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
        scene.add_geometry(pct)
    else:
        meshes = [
            pts3d_to_trimesh(img, pts, msk)
            for img, pts, msk in zip(imgs, pts3d, mask)
        ]
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    if show_cam:
        for i, pose_c2w in enumerate(cams2world):
            if isinstance(cam_color, list):
                camera_edge_color = cam_color[i]
            else:
                camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
            add_scene_cam(
                scene,
                pose_c2w,
                camera_edge_color,
                None if transparent_cams else imgs[i],
                focals[i],
                imsize=imgs[i].shape[1::-1],
                screen_width=cam_size,
            )

    # Transform the whole scene by inv(cams2world[0] @ OPENGL @ rot), where
    # rot is a 180-degree rotation about the y axis.
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler("y", np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))

    if save_name is None:
        save_name = "scene"
    outfile = os.path.join(outdir, save_name + ".glb")
    if not silent:
        print("(exporting 3D scene to", outfile, ")")
    scene.export(file_obj=outfile)
    return outfile
@dataclasses.dataclass
class CameraState:
    """Snapshot of a viewer camera (field of view, aspect ratio and pose)."""

    fov: float  # field of view in radians; get_K treats it as the vertical FOV
    aspect: float  # aspect ratio (not used by get_K)
    c2w: np.ndarray  # 4x4 camera-to-world matrix

    def get_K(self, img_wh):
        """Return 3x3 pinhole intrinsics for an image of size ``(W, H)``.

        The focal length follows from the vertical field of view and the image
        height; the principal point sits at the image centre.
        """
        width, height = img_wh
        half_fov = self.fov / 2.0
        focal = height / 2.0 / np.tan(half_fov)
        center_x = width / 2.0
        center_y = height / 2.0
        rows = [
            [focal, 0.0, center_x],
            [0.0, focal, center_y],
            [0.0, 0.0, 1.0],
        ]
        return np.array(rows)
def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2):
    """Render a vertical colorbar as an RGB float image.

    :param h: target height in pixels; the rendered bar is resized (keeping
        its aspect ratio) when its height differs
    :param vmin: value at the bottom of the bar
    :param vmax: value at the top of the bar
    :param cmap_name: matplotlib colormap name
    :param label: optional label drawn next to the bar
    :param cbar_precision: decimals in the tick labels; 0 prints integers
    :return: float32 array of shape (h, W, 3) with values in [0, 1]
    """
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)

    ax = fig.add_subplot(111)
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap
    # remains available across versions.
    cmap = plt.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )

    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # str() of a float rounded to 0 decimals ends in ".0"; strip it.
        tick_label = [x[:-2] for x in tick_label]

    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)

    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()
    # print_to_buffer returns RGBA bytes; keep RGB and scale to [0, 1].
    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    im = im[:, :, :3].astype(np.float32) / 255.0
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
    return im
def colorize_np(
    x,
    cmap_name="jet",
    mask=None,
    range=None,
    append_cbar=False,
    cbar_in_image=False,
    cbar_precision=2,
):
    """
    turn a grayscale image into a color image
    :param x: input grayscale, [H, W]; the caller's array is not modified
    :param cmap_name: the colorization method (matplotlib colormap name)
    :param mask: boolean mask image, [H, W]; pixels outside it are painted white
    :param range: the range for scaling, automatic if None, [min, max]
    :param append_cbar: if append the color bar
    :param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image
    :param cbar_precision: decimals shown in the colorbar tick labels
    :return: colorized image, [H, W, 3] (wider if the colorbar is appended
        outside the image), values in [0, 1]
    """
    if range is not None:
        vmin, vmax = range
    elif mask is not None and np.any(mask):
        masked = x[mask]
        nonzero = masked[masked != 0]
        # Use the smallest non-zero masked value; fall back to the plain
        # minimum when every masked value is zero.
        vmin = np.min(nonzero) if nonzero.size else np.min(masked)
        vmax = np.max(masked)
        # Build a new array instead of writing into the caller's input.
        x = np.where(mask, x, vmin)
    else:
        # Also used when the mask selects nothing: the output is then all
        # white anyway, but min() on an empty selection would raise.
        vmin, vmax = np.percentile(x, (1, 100))
    vmax += 1e-6

    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    cmap = plt.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]

    if mask is not None:
        weight = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * weight + np.ones_like(x_new) * (1.0 - weight)

    if append_cbar:
        # Rendering the colorbar goes through matplotlib; only do it when used.
        cbar = get_vertical_colorbar(
            h=x.shape[0],
            vmin=vmin,
            vmax=vmax,
            cmap_name=cmap_name,
            cbar_precision=cbar_precision,
        )
        if cbar_in_image:
            x_new[:, -cbar.shape[1] :, :] = cbar
        else:
            x_new = np.concatenate(
                (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
            )
    return x_new
def colorize(
    x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False
):
    """
    turn a grayscale image into a color image
    :param x: torch.Tensor, grayscale image, [H, W] or [B, H, W]
    :param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None;
        values > 0.99 count as valid, and each 2D mask is eroded once with a
        3x3 kernel. A 2D mask is shared by every frame of a batched x.
    :param cmap_name, range, append_cbar, cbar_in_image: forwarded to colorize_np
    :return: float tensor on x's device, [B, H, W, 3]; the leading dim is
        squeezed when it is 1, giving [H, W, 3] (wider when a colorbar is
        appended outside the image)
    """
    device = x.device
    # detach() accepts tensors that require grad; copy() keeps the caller's
    # (possibly shared-memory CPU) tensor untouched.
    x = x.detach().cpu().numpy().copy()
    if x.ndim == 2:
        x = x[None]

    frame_masks = [None] * len(x)
    if mask is not None:
        mask = mask.detach().cpu().numpy() > 0.99
        if mask.ndim == 2:
            mask = np.broadcast_to(mask, x.shape)
        kernel = np.ones((3, 3), np.uint8)
        # Erode each 2D mask exactly once (not the stacked 3D array, and not
        # cumulatively on every loop iteration).
        frame_masks = [
            cv2.erode(m.astype(np.uint8), kernel, iterations=1).astype(bool)
            for m in mask
        ]

    out = [
        torch.from_numpy(
            colorize_np(x_, cmap_name, m, range, append_cbar, cbar_in_image)
        )
        .to(device)
        .float()
        for x_, m in zip(x, frame_masks)
    ]
    return torch.stack(out).squeeze(0)
class PointCloudViewer:
    """Interactive viser viewer for per-frame point clouds and cameras.

    Each input frame is placed under the scene node ``/frames/<step>``. A
    "Playback" GUI steps through the frames. Every connected client gets a
    button that runs ``inference_step`` on a ray map built from its current
    camera, and the predicted points are added to the scene.
    """

    def __init__(
        self,
        model,
        state_args,
        pc_list,
        color_list,
        conf_list,
        cam_dict,
        image_mask=None,
        edge_color_list=None,
        device="cpu",
        port=8080,
        show_camera=True,
        vis_threshold=1,
        size=512,
    ):
        """Start the viser server, index the input frames and build the GUI.

        Args:
            model: model passed to ``inference_step`` for ray-map queries.
            state_args: sequence of model states; the last one is used for
                ray-map queries.
            pc_list: per-frame point maps.
            color_list: per-frame colours matching ``pc_list``.
            conf_list: per-frame confidence maps; only points with confidence
                above ``vis_threshold`` are displayed.
            cam_dict: dict whose "focal", "pp", "R" and "t" entries are
                indexed by frame.
            image_mask: stored as ``self.image_mask``.
            edge_color_list: optional per-frame border colours (entries may be
                None); the argument itself may be None.
            device: device passed to ``inference_step``.
            port: port of the viser server.
            show_camera: whether a camera frustum is added for every frame.
            vis_threshold: confidence threshold for displayed points.
            size: 512 queries a 512x384 ray map, any other value 224x224.
        """
        self.model = model
        self.size = size
        self.state_args = state_args
        self.server = viser.ViserServer(port=port)
        self.server.set_up_direction("-y")
        self.device = device
        self.conf_list = conf_list
        self.vis_threshold = vis_threshold
        self.tt = lambda x: torch.from_numpy(x).float().to(device)
        self.pcs, self.all_steps = self.read_data(
            pc_list, color_list, conf_list, edge_color_list
        )
        self.cam_dict = cam_dict
        self.num_frames = len(self.all_steps)
        self.image_mask = image_mask
        self.show_camera = show_camera
        self.on_replay = False
        self.vis_pts_list = []
        self.traj_list = []
        self.orig_img_list = [x[0] for x in color_list]
        self.via_points = []

        gui_reset_up = self.server.gui.add_button(
            "Reset up direction",
            hint="Set the camera control 'up' direction to the current camera's 'up'.",
        )

        @gui_reset_up.on_click
        def _(event: viser.GuiEvent) -> None:
            client = event.client
            assert client is not None
            client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array(
                [0.0, -1.0, 0.0]
            )

        # Display mode consumed by the loop in animate(): 4D shows only the
        # current frame, 3D shows every frame up to the current one.
        button3 = self.server.gui.add_button("4D (Only Show Current Frame)")
        button4 = self.server.gui.add_button("3D (Show All Frames)")
        self.is_render = False
        self.fourd = False

        @button3.on_click
        def _(event: viser.GuiEvent) -> None:
            self.fourd = True

        @button4.on_click
        def _(event: viser.GuiEvent) -> None:
            self.fourd = False

        self.focal_slider = self.server.add_gui_slider(
            "Focal Length",
            min=0.1,
            max=99999,
            step=1,
            initial_value=533,
        )
        self.psize_slider = self.server.add_gui_slider(
            "Point Size",
            min=0.0001,
            max=0.1,
            step=0.0001,
            initial_value=0.0005,
        )
        self.camsize_slider = self.server.add_gui_slider(
            "Camera Size",
            min=0.01,
            max=0.5,
            step=0.01,
            initial_value=0.1,
        )
        self.pc_handles = []
        self.cam_handles = []

        @self.psize_slider.on_update
        def _(_) -> None:
            for handle in self.pc_handles:
                handle.point_size = self.psize_slider.value

        @self.camsize_slider.on_update
        def _(_) -> None:
            for handle in self.cam_handles:
                handle.scale = self.camsize_slider.value
                handle.line_thickness = 0.03 * handle.scale

        self.server.on_client_connect(self._connect_client)

    def get_camera_state(self, client: viser.ClientHandle) -> CameraState:
        """Return the client's current camera as a ``CameraState``.

        ``c2w`` is assembled as a 4x4 matrix from the camera's ``wxyz``
        rotation and ``position``.
        """
        camera = client.camera
        c2w = np.concatenate(
            [
                np.concatenate(
                    [tf.SO3(camera.wxyz).as_matrix(), camera.position[:, None]], 1
                ),
                [[0, 0, 0, 1]],
            ],
            0,
        )
        return CameraState(
            fov=camera.fov,
            aspect=camera.aspect,
            c2w=c2w,
        )

    @staticmethod
    def generate_pseudo_intrinsics(h, w):
        """Return float32 3x3 intrinsics whose focal length is the image
        diagonal and whose principal point is (w // 2, h // 2)."""
        focal = (h**2 + w**2) ** 0.5
        return np.array([[focal, 0, w // 2], [0, focal, h // 2], [0, 0, 1]]).astype(
            np.float32
        )

    def get_ray_map(self, c2w, h, w, intrinsics=None):
        """Build an (h, w, 6) ray map: per-pixel ray origin and unit direction.

        The origin is the translation of ``c2w`` broadcast to every pixel.
        Pseudo intrinsics are used when ``intrinsics`` is None.

        NOTE(review): directions are mapped through ``c2w`` with a homogeneous
        coordinate of 1, so the camera translation is added before
        normalisation (R @ d + t rather than R @ d). Kept as-is; confirm this
        matches the convention the model expects.
        """
        if intrinsics is None:
            intrinsics = self.generate_pseudo_intrinsics(h, w)
        i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
        grid = np.stack([i, j, np.ones_like(i)], axis=-1)
        ro = c2w[:3, 3]
        rd = np.linalg.inv(intrinsics) @ grid.reshape(-1, 3).T
        rd = (c2w @ np.vstack([rd, np.ones_like(rd[0])])).T[:, :3].reshape(h, w, 3)
        rd = rd / np.linalg.norm(rd, axis=-1, keepdims=True)
        ro = np.broadcast_to(ro, (h, w, 3))
        ray_map = np.concatenate([ro, rd], axis=-1)
        return ray_map

    @staticmethod
    def set_camera_loc(camera, pose, K):
        """Place a viser camera at ``pose`` with FOV / aspect derived from ``K``.

        pose: 4x4 matrix
        K: 3x3 matrix

        Declared static: the function never used ``self``, and without the
        decorator an instance call would have shifted every argument.
        """
        fx = K[0, 0]
        cx, cy = K[0, 2], K[1, 2]
        aspect = float(cx) / float(cy)
        # tan(fov / 2) = cx / fx, the same convention add_camera uses
        # (previously 2 * cx / fx, which doubled the tangent).
        fov = 2 * np.arctan(cx / fx)
        wxyz_xyz = tf.SE3.from_matrix(pose).wxyz_xyz
        camera.wxyz = wxyz_xyz[:4]
        camera.position = wxyz_xyz[4:]
        camera.fov = fov
        camera.aspect = aspect

    def _connect_client(self, client: viser.ClientHandle):
        """Add camera read-outs and the ray-map inference button for ``client``."""
        from src.dust3r.inference import inference_step
        from src.dust3r.utils.geometry import geotrf

        wxyz_panel = client.gui.add_text("wxyz:", f"{client.camera.wxyz}")
        position_panel = client.gui.add_text("position:", f"{client.camera.position}")
        fov_panel = client.gui.add_text(
            "fov:", f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}"
        )
        aspect_panel = client.gui.add_text("aspect:", "1.0")

        @client.camera.on_update
        def _(_: viser.CameraHandle):
            with self.server.atomic():
                wxyz_panel.value = f"{client.camera.wxyz}"
                position_panel.value = f"{client.camera.position}"
                fov_panel.value = (
                    f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}"
                )
                aspect_panel.value = "1.0"

        gui_set_current_camera = client.gui.add_button(
            "Set Current Camera to Infer Raymap"
        )

        @gui_set_current_camera.on_click
        def _(_) -> None:
            try:
                cam = self.get_camera_state(client)
                cam.fov = 2 * np.arctan(self.size / self.focal_slider.value)
                cam.aspect = (512 / 384) if self.size == 512 else 1.0
                pose = cam.c2w

                # Query resolution: 512x384 for size 512, otherwise 224x224.
                h, w = (384, 512) if self.size == 512 else (224, 224)
                intrins = self.generate_pseudo_intrinsics(h, w)
                raymap = torch.from_numpy(self.get_ray_map(pose, h, w, intrins))[
                    None
                ].float()
                view = {
                    # NaN placeholder image; img_mask=False / ray_mask=True
                    # below mark this view as a ray-map query.
                    "img": torch.full((1, 3, h, w), torch.nan),
                    "ray_map": raymap,
                    "true_shape": torch.from_numpy(np.int32([raymap.shape[1:-1]])),
                    "idx": self.num_frames + 1,
                    "instance": str(self.num_frames + 1),
                    "camera_pose": torch.from_numpy(
                        np.eye(4).astype(np.float32)
                    ).unsqueeze(0),
                    "img_mask": torch.tensor(False).unsqueeze(0),
                    "ray_mask": torch.tensor(True).unsqueeze(0),
                    "update": torch.tensor(False).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
                print("Start Inference Raymap")
                output = inference_step(
                    view, self.state_args[-1], self.model, device=self.device
                )
                print("Finish Inference Raymap")

                pred = output["pred"]
                # Predicted points are in the query camera's frame; move them
                # to world coordinates with the client camera pose.
                pts3ds = geotrf(pose[None], pred["pts3d_in_self_view"].cpu().numpy())
                colors = 0.5 * (pred["rgb"].cpu().numpy() + 1.0)
                conf = pred["conf"].cpu().numpy()
                pts3ds, colors = self.parse_pc_data(
                    pts3ds, colors, set_border_color=True
                )
                # Same visibility threshold as the input frames (parse_pc_data).
                mask = (conf > self.vis_threshold).reshape(-1)

                self.num_frames += 1
                self.pc_handles.append(
                    self.server.add_point_cloud(
                        name=f"/frames/{self.num_frames-1}/pred_pts",
                        points=pts3ds[mask],
                        colors=colors[mask],
                        point_size=0.005,
                    )
                )
                self.server.add_camera_frustum(
                    name=f"/frames/{self.num_frames-1}/camera",
                    fov=cam.fov,
                    aspect=cam.aspect,
                    wxyz=client.camera.wxyz,
                    position=client.camera.position,
                    scale=0.1,
                    color=[64, 179, 230],
                )
                print("Adding new pointcloud: ", pts3ds.shape)
            except Exception:
                # Best effort: keep the GUI alive, but report the full
                # traceback instead of only the exception message.
                import traceback

                traceback.print_exc()

    @staticmethod
    def set_color_border(image, border_width=5, color=(1, 0, 0)):
        """Paint a ``border_width``-pixel frame of ``color`` around ``image``.

        Channels 0-2 of the HxWxC array are modified in place and the same
        array is returned. A non-positive width leaves the image untouched
        (``image[-0:]`` would otherwise select the whole image).
        """
        if border_width <= 0:
            return image
        for c in range(3):
            image[:border_width, :, c] = color[c]
            image[-border_width:, :, c] = color[c]
            image[:, :border_width, c] = color[c]
            image[:, -border_width:, c] = color[c]
        return image

    def read_data(self, pc_list, color_list, conf_list, edge_color_list=None):
        """Index the per-frame inputs by step and assign per-camera colours.

        Returns:
            (pcs, step_list): ``pcs`` maps step -> dict with "pc", "color",
            "conf" and "edge_color"; ``step_list`` lists the steps in order.
            Also sets ``self.camera_colors`` to viridis colours spread evenly
            over the frames.
        """
        pcs = {}
        step_list = []
        for step, pc in enumerate(pc_list):
            pcs[step] = {
                "pc": pc,
                "color": color_list[step],
                "conf": conf_list[step],
                # The whole list is optional (default None), not just entries.
                "edge_color": (
                    None if edge_color_list is None else edge_color_list[step]
                ),
            }
            step_list.append(step)

        num = len(pc_list)
        # Spread frames over [0, 1]; max(..., 1) avoids 0/0 for a single frame
        # and an empty .max() for no frames.
        normalized_indices = np.arange(num) / max(num - 1, 1)
        cmap = cm.viridis
        self.camera_colors = cmap(normalized_indices)
        return pcs, step_list

    def parse_pc_data(
        self,
        pc,
        color,
        conf=None,
        edge_color=(0.251, 0.702, 0.902),
        set_border_color=False,
    ):
        """Flatten a point map and its colours for display.

        Args:
            pc: point map, reshaped to (N, 3).
            color: colours. When a border is drawn, ``color[0]`` is used and
                painted in place; the result is then reshaped to (N, 3).
            conf: optional confidences; ``conf[0]`` is flattened and only
                points with confidence above ``self.vis_threshold`` are kept.
            edge_color: border colour, or None for no border.
            set_border_color: whether to paint the border.

        Returns:
            (points, colours). If the colours contain NaNs, every point is
            coloured pure blue instead.
        """
        pred_pts = pc.reshape(-1, 3)  # [N, 3]
        if set_border_color and edge_color is not None:
            color = self.set_color_border(color[0], color=edge_color)
        if np.isnan(color).any():
            color = np.zeros((pred_pts.shape[0], 3))
            color[:, 2] = 1
        else:
            color = color.reshape(-1, 3)
        if conf is not None:
            conf = conf[0].reshape(-1)
            pred_pts = pred_pts[conf > self.vis_threshold]
            color = color[conf > self.vis_threshold]
        return pred_pts, color

    def add_pc(self, step):
        """Add the (confidence-filtered) point cloud of frame ``step``."""
        pc = self.pcs[step]["pc"]
        color = self.pcs[step]["color"]
        conf = self.pcs[step]["conf"]
        edge_color = self.pcs[step].get("edge_color", None)
        pred_pts, color = self.parse_pc_data(
            pc, color, conf, edge_color, set_border_color=True
        )
        self.vis_pts_list.append(pred_pts)
        self.pc_handles.append(
            self.server.add_point_cloud(
                name=f"/frames/{step}/pred_pts",
                points=pred_pts,
                colors=color,
                point_size=0.0005,
            )
        )

    def add_camera(self, step):
        """Add the camera frustum of frame ``step`` from ``self.cam_dict``.

        fov = 2 * arctan(pp_x / focal) and aspect = pp_x / pp_y.
        """
        cam = self.cam_dict
        focal = cam["focal"][step]
        pp = cam["pp"][step]
        R = cam["R"][step]
        t = cam["t"][step]
        q = tf.SO3.from_matrix(R).wxyz
        fov = 2 * np.arctan(pp[0] / focal)
        aspect = pp[0] / pp[1]
        self.traj_list.append((q, t))
        self.cam_handles.append(
            self.server.add_camera_frustum(
                name=f"/frames/{step}/camera",
                fov=fov,
                aspect=aspect,
                wxyz=q,
                position=t,
                scale=0.1,
                color=(50, 205, 50),
            )
        )

    def animate(self):
        """Build the playback GUI and the frame nodes, then loop forever.

        While "Playing" is checked, the timestep advances once per
        ``1 / FPS`` seconds. Frame visibility is refreshed each iteration:
        frames up to the current one in 3D mode, only the current one in 4D
        mode. This method never returns.
        """
        # Playback cycles only over the frames created below. Ray-map
        # inference also increments self.num_frames but adds no entry to
        # self.frame_nodes, so wrapping with num_frames could index past it.
        num_steps = len(self.all_steps)

        with self.server.add_gui_folder("Playback"):
            gui_timestep = self.server.add_gui_slider(
                "Train Step",
                min=0,
                max=num_steps - 1,
                step=1,
                initial_value=0,
                disabled=False,
            )
            gui_next_frame = self.server.add_gui_button("Next Step", disabled=False)
            gui_prev_frame = self.server.add_gui_button("Prev Step", disabled=False)
            gui_playing = self.server.add_gui_checkbox("Playing", False)
            gui_framerate = self.server.add_gui_slider(
                "FPS", min=1, max=60, step=0.1, initial_value=1
            )
            gui_framerate_options = self.server.add_gui_button_group(
                "FPS options", ("10", "20", "30", "60")
            )

        @gui_next_frame.on_click
        def _(_) -> None:
            gui_timestep.value = (gui_timestep.value + 1) % num_steps

        @gui_prev_frame.on_click
        def _(_) -> None:
            gui_timestep.value = (gui_timestep.value - 1) % num_steps

        @gui_playing.on_update
        def _(_) -> None:
            gui_timestep.disabled = gui_playing.value
            gui_next_frame.disabled = gui_playing.value
            gui_prev_frame.disabled = gui_playing.value

        @gui_framerate_options.on_click
        def _(_) -> None:
            gui_framerate.value = int(gui_framerate_options.value)

        prev_timestep = gui_timestep.value

        @gui_timestep.on_update
        def _(_) -> None:
            nonlocal prev_timestep
            current_timestep = gui_timestep.value
            with self.server.atomic():
                self.frame_nodes[current_timestep].visible = True
                self.frame_nodes[prev_timestep].visible = False
            prev_timestep = current_timestep
            self.server.flush()  # Optional!

        self.server.add_frame(
            "/frames",
            show_axes=False,
        )
        self.frame_nodes = []
        for step in self.all_steps:
            self.frame_nodes.append(
                self.server.add_frame(
                    f"/frames/{step}",
                    show_axes=False,
                )
            )
            self.add_pc(step)
            if self.show_camera:
                self.add_camera(step)

        prev_timestep = gui_timestep.value
        while True:
            if self.on_replay:
                pass
            else:
                if gui_playing.value:
                    gui_timestep.value = (gui_timestep.value + 1) % num_steps
                for i, frame_node in enumerate(self.frame_nodes):
                    frame_node.visible = (
                        i <= gui_timestep.value
                        if not self.fourd
                        else i == gui_timestep.value
                    )
            time.sleep(1.0 / gui_framerate.value)

    def run(self):
        """Start the viewer; blocks forever."""
        self.animate()
        # animate() loops forever, so this is only reached if that changes.
        while True:
            time.sleep(10.0)