from typing import List, Optional, Tuple

import numpy as np
import torch
from torch.nn import functional as F


def get_position_map_from_depth(depth, mask, intrinsics, extrinsics, image_wh=None):
    """Compute the position map from the depth map and the camera parameters for a batch of views.

    Args:
        depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
        mask (torch.Tensor): The masks with the shape (B, H, W, 1).
        intrinsics (torch.Tensor): The camera intrinsics matrices with the shape (B, 3, 3).
        extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
        image_wh (Tuple[int, int]): The image width and height.

    Returns:
        torch.Tensor: The position maps with the shape (B, H, W, 3).
    """
    if image_wh is None:
        image_wh = depth.shape[2], depth.shape[1]

    B, H, W, _ = depth.shape
    depth = depth.squeeze(-1)

    u_coord, v_coord = torch.meshgrid(
        torch.arange(image_wh[0]), torch.arange(image_wh[1]), indexing="xy"
    )
    u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
    v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)

    # Compute the position map by back-projecting depth pixels to 3D space
    x = (
        (u_coord - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1))
        * depth
        / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
    )
    y = (
        (v_coord - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1))
        * depth
        / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
    )
    z = depth

    # Stack to form the 3D coordinates in the camera frame
    camera_coords = torch.stack([x, y, z], dim=-1)

    # Apply the extrinsic matrix to get coordinates in the world frame
    coords_homogeneous = torch.nn.functional.pad(
        camera_coords, (0, 1), "constant", 1.0
    )  # Add a homogeneous coordinate
    world_coords = torch.matmul(
        coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
    ).view(B, H, W, 4)

    # Apply the mask to the position map
    position_map = world_coords[..., :3] * mask

    return position_map
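
# Usage sketch (illustrative only; the tensors below are placeholder values, not
# inputs from the original pipeline):
#
#     depth = torch.rand(2, 256, 256, 1)                       # (B, H, W, 1)
#     mask = (depth > 0.05).float()                            # (B, H, W, 1)
#     K = torch.tensor(
#         [[256.0, 0.0, 128.0], [0.0, 256.0, 128.0], [0.0, 0.0, 1.0]]
#     ).expand(2, -1, -1)                                       # (B, 3, 3)
#     E = torch.eye(4).expand(2, -1, -1)                        # (B, 4, 4)
#     pos = get_position_map_from_depth(depth, mask, K, E)     # (B, 256, 256, 3)
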
def get_position_map_from_depth_ortho(
    depth, mask, extrinsics, ortho_scale, image_wh=None
):
    """Compute the position map from the depth map and the camera parameters for a batch of views
    using orthographic projection with a given ortho_scale.

    Args:
        depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
        mask (torch.Tensor): The masks with the shape (B, H, W, 1).
        extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
        ortho_scale (float or torch.Tensor): The scaling factor for the orthographic projection;
            a scalar or a tensor broadcastable to (B, H, W), e.g. of shape (B, 1, 1).
        image_wh (Tuple[int, int]): Optional. The image width and height.

    Returns:
        torch.Tensor: The position maps with the shape (B, H, W, 3).
    """
    if image_wh is None:
        image_wh = depth.shape[2], depth.shape[1]

    B, H, W, _ = depth.shape
    depth = depth.squeeze(-1)

    # Generate the grid of coordinates in the image space
    u_coord, v_coord = torch.meshgrid(
        torch.arange(0, image_wh[0]), torch.arange(0, image_wh[1]), indexing="xy"
    )
    u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
    v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)

    # Compute the position map using orthographic projection with ortho_scale
    x = (u_coord - image_wh[0] / 2) * ortho_scale / image_wh[0]
    y = (v_coord - image_wh[1] / 2) * ortho_scale / image_wh[1]
    z = depth

    # Stack to form the 3D coordinates in the camera frame
    camera_coords = torch.stack([x, y, z], dim=-1)

    # Apply the extrinsic matrix to get coordinates in the world frame
    coords_homogeneous = torch.nn.functional.pad(
        camera_coords, (0, 1), "constant", 1.0
    )  # Add a homogeneous coordinate
    world_coords = torch.matmul(
        coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
    ).view(B, H, W, 4)

    # Apply the mask to the position map
    position_map = world_coords[..., :3] * mask

    return position_map
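
# Usage sketch for the orthographic variant (illustrative placeholder values; a scalar
# ortho_scale is used here so it broadcasts against the (B, H, W) pixel grid):
#
#     depth = torch.rand(2, 256, 256, 1)
#     mask = torch.ones_like(depth)
#     E = torch.eye(4).expand(2, -1, -1)
#     pos = get_position_map_from_depth_ortho(depth, mask, E, ortho_scale=1.35)   # (2, 256, 256, 3)
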
def get_opencv_from_blender(matrix_world, fov=None, image_size=None):
    # convert matrix_world to opencv format extrinsics
    opencv_world_to_cam = matrix_world.inverse()
    opencv_world_to_cam[1, :] *= -1
    opencv_world_to_cam[2, :] *= -1
    R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]

    if fov is None:  # orthographic camera
        return R, T

    R, T = R.unsqueeze(0), T.unsqueeze(0)

    # convert fov to opencv format intrinsics
    focal = 1 / np.tan(fov / 2)
    intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
    opencv_cam_matrix = (
        torch.from_numpy(intrinsics).unsqueeze(0).float().to(matrix_world.device)
    )
    opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2]).to(
        matrix_world.device
    )
    opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2

    return R, T, opencv_cam_matrix
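
# Usage sketch (illustrative values; the identity matrix stands in for a real Blender
# camera-to-world matrix, and fov is given in radians):
#
#     cam2world = torch.eye(4)
#     R, T, K = get_opencv_from_blender(cam2world, fov=np.deg2rad(50.0), image_size=256)
#     R_ortho, T_ortho = get_opencv_from_blender(cam2world)   # fov=None: extrinsics only
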
def get_ray_directions(
    H: int,
    W: int,
    focal: float,
    principal: Optional[Tuple[float, float]] = None,
    use_pixel_centers: bool = True,
) -> torch.Tensor:
    """
    Get ray directions for all pixels in camera coordinates.

    Args:
        H, W, focal, principal, use_pixel_centers: image height, width, focal length,
            principal point, and whether to use pixel centers.

    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinates.
    """
    pixel_center = 0.5 if use_pixel_centers else 0
    cx, cy = (W / 2, H / 2) if principal is None else principal
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32) + pixel_center,
        torch.arange(H, dtype=torch.float32) + pixel_center,
        indexing="xy",
    )
    directions = torch.stack(
        [(i - cx) / focal, -(j - cy) / focal, -torch.ones_like(i)], -1
    )
    return F.normalize(directions, dim=-1)
def get_rays(
    directions: torch.Tensor, c2w: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get ray origins and directions in world coordinates from camera-frame ray directions.

    Args:
        directions: (H, W, 3) ray directions in camera coordinates
        c2w: (4, 4) camera-to-world transformation matrix

    Outputs:
        rays_o, rays_d: (H, W, 3) ray origins and directions in world coordinates
    """
    # Rotate ray directions from the camera frame to the world frame
    rays_d = directions @ c2w[:3, :3].T
    # All rays share the camera center as their origin
    rays_o = c2w[:3, 3].expand(rays_d.shape)
    return rays_o, rays_d
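
# Usage sketch for the two ray helpers together (illustrative values; the focal length
# follows the pixel-space convention used in get_plucker_embeds_from_cameras below):
#
#     focal = 0.5 * 256 / np.tan(0.5 * np.deg2rad(50.0))
#     directions = get_ray_directions(256, 256, focal)         # (256, 256, 3)
#     rays_o, rays_d = get_rays(directions, torch.eye(4))      # (256, 256, 3) each
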
def compute_plucker_embed(
    c2w: torch.Tensor, image_width: int, image_height: int, focal: float
) -> torch.Tensor:
    """
    Computes Plucker coordinates for a camera.

    Args:
        c2w: (4, 4) camera-to-world transformation matrix
        image_width: Image width
        image_height: Image height
        focal: Focal length of the camera

    Returns:
        plucker: (6, H, W) Plucker embedding
    """
    directions = get_ray_directions(image_height, image_width, focal)
    rays_o, rays_d = get_rays(directions, c2w)

    # Cross product to get the moment part of the Plucker coordinates
    cross = torch.cross(rays_o, rays_d, dim=-1)
    plucker = torch.cat((rays_d, cross), dim=-1)

    return plucker.permute(2, 0, 1)
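
# Usage sketch (illustrative values; the identity camera-to-world matrix is a placeholder):
#
#     focal = 0.5 * 128 / np.tan(0.5 * np.deg2rad(50.0))
#     plucker = compute_plucker_embed(torch.eye(4), 128, 128, focal)   # (6, 128, 128)
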
def get_plucker_embeds_from_cameras(
    c2w: List[torch.Tensor], fov: List[float], image_size: int
) -> torch.Tensor:
    """
    Given lists of camera transformations and fov, returns the batched Plucker embeddings.

    Args:
        c2w: list of camera-to-world transformation matrices
        fov: list of field of view values
        image_size: size of the image

    Returns:
        plucker_embeds: (B, 6, H, W) batched Plucker embeddings
    """
    plucker_embeds = []
    for cam_matrix, cam_fov in zip(c2w, fov):
        focal = 0.5 * image_size / np.tan(0.5 * cam_fov)
        plucker = compute_plucker_embed(cam_matrix, image_size, image_size, focal)
        plucker_embeds.append(plucker)
    return torch.stack(plucker_embeds)
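
# Usage sketch (illustrative values; four identity cameras with a 50-degree fov):
#
#     c2w_list = [torch.eye(4) for _ in range(4)]
#     fov_list = [np.deg2rad(50.0)] * 4
#     embeds = get_plucker_embeds_from_cameras(c2w_list, fov_list, 128)   # (4, 6, 128, 128)
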
def get_plucker_embeds_from_cameras_ortho(
    c2w: List[torch.Tensor], ortho_scale: List[float], image_size: int
) -> torch.Tensor:
    """
    Given lists of camera transformations and orthographic scales, returns the batched Plucker embeddings.

    Args:
        c2w: list of camera-to-world transformation matrices
        ortho_scale: list of orthographic scale values
        image_size: size of the image

    Returns:
        plucker_embeds: (B, 6, H, W) batched Plucker embeddings
    """
    plucker_embeds = []
    for cam_matrix, scale in zip(c2w, ortho_scale):
        # Convert the Blender camera-to-world matrix to OpenCV world-to-camera R, T
        R, T = get_opencv_from_blender(cam_matrix)
        cam_pos = -R.T @ T
        view_dir = R.T @ torch.tensor([0, 0, 1]).float().to(cam_matrix.device)
        # Normalize the camera position so it encodes only a direction
        cam_pos = F.normalize(cam_pos, dim=0)
        # Broadcast the constant (view direction, camera position) pair to every pixel
        plucker = torch.cat([view_dir, cam_pos])
        plucker = plucker.unsqueeze(-1).unsqueeze(-1).repeat(1, image_size, image_size)
        plucker_embeds.append(plucker)
    plucker_embeds = torch.stack(plucker_embeds)
    return plucker_embeds
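
# Usage sketch (illustrative values; the scale list only pairs one entry with each camera
# here, since the constant per-view embedding above does not otherwise use it):
#
#     c2w_list = [torch.eye(4) for _ in range(4)]
#     scales = [1.35] * 4
#     embeds = get_plucker_embeds_from_cameras_ortho(c2w_list, scales, 128)   # (4, 6, 128, 128)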