# NOTE(review): removed non-code artifacts ("Spaces:" / "Sleeping" / "Sleeping")
# left at the top of the file by a copy-paste or extraction tool.
| import os | |
| import torch | |
| import cv2 | |
| import random | |
| import numpy as np | |
| from torchvision import transforms | |
| from pytorch3d.renderer import TexturesUV | |
| from pytorch3d.ops import interpolate_face_attributes | |
| from PIL import Image | |
| from tqdm import tqdm | |
| # customized | |
| import sys | |
| sys.path.append(".") | |
| from lib.camera_helper import init_camera | |
| from lib.render_helper import init_renderer, render | |
| from lib.shading_helper import ( | |
| BlendParams, | |
| init_soft_phong_shader, | |
| init_flat_texel_shader, | |
| ) | |
| from lib.vis_helper import visualize_outputs, visualize_quad_mask | |
| from lib.constants import * | |
def get_all_4_locations(values_y, values_x):
    """Return the four integer texel corners surrounding each fractional (y, x).

    Given fractional row coordinates ``values_y`` and column coordinates
    ``values_x`` (1-D float tensors of equal length N), produce two long
    tensors of length 4*N: the rows [floor, floor, ceil, ceil] and the
    columns [floor, ceil, floor, ceil], i.e. all four neighbouring texels
    for each input coordinate, in matching order.
    """
    floor_y, ceil_y = torch.floor(values_y), torch.ceil(values_y)
    floor_x, ceil_x = torch.floor(values_x), torch.ceil(values_x)

    rows = torch.cat((floor_y, floor_y, ceil_y, ceil_y), dim=0).long()
    cols = torch.cat((floor_x, ceil_x, floor_x, ceil_x), dim=0).long()

    return rows, cols
def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device):
    """
    compose quad mask:
        -> 0: background
        -> 1: old
        -> 2: update
        -> 3: new
    """
    to_tensor = transforms.ToTensor()
    new_mask_tensor = to_tensor(new_mask_image).to(device)
    update_mask_tensor = to_tensor(update_mask_image).to(device)
    old_mask_tensor = to_tensor(old_mask_image).to(device)

    all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor

    # label texels by priority: later assignments win, so new (3) overrides
    # update (2), which overrides old (1); untouched texels stay 0 (background)
    quad_mask_tensor = torch.zeros_like(all_mask_tensor)
    for label, mask in enumerate((old_mask_tensor, update_mask_tensor, new_mask_tensor), start=1):
        quad_mask_tensor[mask == 1] = label

    return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
def compute_view_heat(similarity_tensor, quad_mask_tensor):
    """Score a view as the weighted fraction of pixels per quad-mask label.

    For each label in QUAD_WEIGHTS, the fraction of pixels carrying that label
    is multiplied by its weight and accumulated. Returns a 0-dim tensor.

    NOTE: ``similarity_tensor`` is accepted for interface symmetry but not
    read by this implementation.
    """
    num_total_pixels = quad_mask_tensor.numel()
    return sum(
        (quad_mask_tensor == label).sum() * weight / num_total_pixels
        for label, weight in QUAD_WEIGHTS.items()
    )
def select_viewpoint(selected_view_ids, view_punishments,
    mode, dist_list, elev_list, azim_list, sector_list, view_idx,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, use_principle=False
):
    """Choose the camera pose (dist, elev, azim, sector) for the next step.

    Modes:
        - "sequential": cycle through the candidate lists in order,
          indexing with ``view_idx`` modulo the number of candidates.
        - "heuristic": render every candidate view via
          ``render_one_view_and_build_masks`` and pick the one with the
          highest "heat" score, scaled by its punishment factor; when
          ``use_principle`` is set, the first 6 calls take views 0..5 verbatim.
        - "random": uniform random choice among the candidates.

    Side effects: appends the chosen index to ``selected_view_ids``; in
    "heuristic" mode multiplies the chosen view's entry in
    ``view_punishments`` by 0.01 to discourage re-selection.

    Returns (dist, elev, azim, sector, selected_view_ids, view_punishments).
    Raises NotImplementedError for an unknown ``mode``.
    """
    if mode == "sequential":
        num_views = len(dist_list)

        dist = dist_list[view_idx % num_views]
        elev = elev_list[view_idx % num_views]
        azim = azim_list[view_idx % num_views]
        sector = sector_list[view_idx % num_views]

        selected_view_ids.append(view_idx % num_views)

    elif mode == "heuristic":
        if use_principle and view_idx < 6:
            # principle views: take the first 6 candidates in fixed order
            selected_view_idx = view_idx
        else:
            selected_view_idx = None
            max_heat = 0

            print("=> selecting next view...")
            view_heat_list = []
            # score every candidate by rendering it and computing its heat
            for sample_idx in tqdm(range(len(dist_list))):
                view_heat, *_ = render_one_view_and_build_masks(dist_list[sample_idx], elev_list[sample_idx], azim_list[sample_idx],
                    sample_idx, sample_idx, view_punishments,
                    similarity_texture_cache, exist_texture,
                    mesh, faces, verts_uvs,
                    image_size, faces_per_pixel,
                    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
                    device)

                if view_heat > max_heat:
                    selected_view_idx = sample_idx
                    max_heat = view_heat

                view_heat_list.append(view_heat.item())

            print(view_heat_list)
            print("select view {} with heat {}".format(selected_view_idx, max_heat))

        # NOTE(review): if every candidate's heat is <= 0, selected_view_idx
        # remains None and the indexing below raises TypeError — confirm at
        # least one view always scores positive heat.
        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

        # heavily penalize the chosen view so it is unlikely to win again
        view_punishments[selected_view_idx] *= 0.01

    elif mode == "random":
        selected_view_idx = random.choice(range(len(dist_list)))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

    else:
        raise NotImplementedError()

    return dist, elev, azim, sector, selected_view_ids, view_punishments
def build_backproject_mask(mesh, faces, verts_uvs,
    cameras, reference_image, faces_per_pixel,
    image_size, uv_size, device):
    """Back-project ``reference_image`` (as seen from ``cameras``) into the
    mesh's UV texture space.

    Rasterizes the mesh, interpolates per-pixel UV coordinates, and scatters
    the image values into a (uv_size, uv_size) texture. Returns the first
    channel of that texture as a 2-D tensor.
    """
    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)

    # v grows upward in UV space but rows grow downward in the texture,
    # hence the (1 - v) flip; each UV hits all 4 surrounding texels
    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = faces_per_pixel
    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size)))).float() / 255.
    # replicate values 4x (corner locations) and Kx (faces per pixel) to
    # line up with the flattened location tensors
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    # texture
    texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values.reshape(-1, 3)

    # only channel 0 is returned — callers appear to pass channel-replicated
    # (grayscale) images, so one channel carries all the information;
    # TODO(review): confirm for every caller
    return texture_tensor[:, :, 0]
def build_diffusion_mask(mesh_stuff,
    renderer, exist_texture, similarity_texture_cache, target_value, device, image_size,
    smooth_mask=False, view_threshold=0.01):
    """Render the per-view diffusion masks by texturing a mesh clone with
    binary UV maps and rendering each one.

    Mask semantics:
        - visible: every rasterized texel whose view similarity passes
          ``view_threshold``;
        - new: visible texels not yet present in ``exist_texture``;
        - exist: visible minus new;
        - update: exist texels whose best view (argmax over
          ``similarity_texture_cache``) equals ``target_value``;
        - old: exist minus update.

    Returns (new_mask, update_mask, old_mask, exist_mask) as PIL "L" images.

    NOTE(review): ``smooth_mask`` and ``image_size`` are accepted but never
    read in this body — confirm whether smoothing was meant to be applied.
    """
    mesh, faces, verts_uvs = mesh_stuff
    mask_mesh = mesh.clone()  # NOTE in-place operation - DANGER!!!

    # visible mask => the whole region
    # (the expanded texture is used only to give torch.ones_like a shape)
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=torch.ones_like(exist_texture_expand),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    visible_mask_tensor, _, similarity_map_tensor, *_ = render(mask_mesh, renderer)
    # faces that are too rotated away from the viewpoint will be treated as invisible
    valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
    visible_mask_tensor *= valid_mask_tensor

    # nonexist mask <=> new mask
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=1 - exist_texture_expand,
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    new_mask_tensor, *_ = render(mask_mesh, renderer)
    new_mask_tensor *= valid_mask_tensor

    # exist mask => visible mask - new mask
    exist_mask_tensor = visible_mask_tensor - new_mask_tensor
    exist_mask_tensor[exist_mask_tensor < 0] = 0  # NOTE dilate can lead to overflow

    # all update mask: texels whose best-scoring view is the target view
    mask_mesh.textures = TexturesUV(
        maps=(
            similarity_texture_cache.argmax(0) == target_value
            # # only consider the views that have already appeared before
            # similarity_texture_cache[0:target_value+1].argmax(0) == target_value
        ).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    all_update_mask_tensor, *_ = render(mask_mesh, renderer)

    # current update mask => intersection between all update mask and exist mask
    update_mask_tensor = exist_mask_tensor * all_update_mask_tensor

    # keep mask => exist mask - update mask
    old_mask_tensor = exist_mask_tensor - update_mask_tensor

    # convert batch element 0 of each mask tensor to a grayscale PIL image
    new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
    new_mask = transforms.ToPILImage()(new_mask).convert("L")

    update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
    update_mask = transforms.ToPILImage()(update_mask).convert("L")

    old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
    old_mask = transforms.ToPILImage()(old_mask).convert("L")

    exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
    exist_mask = transforms.ToPILImage()(exist_mask).convert("L")

    return new_mask, update_mask, old_mask, exist_mask
def render_one_view(mesh,
    dist, elev, azim,
    image_size, faces_per_pixel,
    device):
    """Render the mesh from a single camera pose.

    Builds a camera at (dist, elev, azim), attaches a soft Phong shader,
    and renders. Returns the camera, the renderer, the rendered image /
    normal / similarity / depth tensors, and the rasterizer fragments.
    """
    cameras = init_camera(dist, elev, azim, image_size, device)

    shader = init_soft_phong_shader(
        camera=cameras,
        blend_params=BlendParams(),
        device=device,
    )
    renderer = init_renderer(
        cameras,
        shader=shader,
        image_size=image_size,
        faces_per_pixel=faces_per_pixel,
    )

    render_outputs = render(mesh, renderer)
    init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render_outputs

    return (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    )
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
    dist_list, elev_list, azim_list,
    image_size, image_size_scaled, uv_size, faces_per_pixel,
    device):
    """Precompute, for every candidate view, the per-texel view-similarity map
    back-projected into UV space.

    Returns a (num_views, uv_size, uv_size) tensor on ``device`` where slice i
    holds the similarity texture of candidate view i.
    """
    num_candidate_views = len(dist_list)
    cache = torch.zeros(num_candidate_views, uv_size, uv_size).to(device)

    print("=> building similarity texture cache for all views...")
    for view_i in tqdm(range(num_candidate_views)):
        cameras, _, _, _, similarity_tensor, _, _ = render_one_view(
            mesh,
            dist_list[view_i], elev_list[view_i], azim_list[view_i],
            image_size, faces_per_pixel, device)

        # back-project the rendered similarity map (replicated to RGB) into UV space
        similarity_image = transforms.ToPILImage()(similarity_tensor[0, :, :, 0]).convert("RGB")
        cache[view_i] = build_backproject_mask(
            mesh, faces, verts_uvs,
            cameras, similarity_image, faces_per_pixel,
            image_size_scaled, uv_size, device)

    return cache
def render_one_view_and_build_masks(dist, elev, azim,
    selected_view_idx, view_idx, view_punishments,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, save_intermediate=False, smooth_mask=False, view_threshold=0.01):
    """Render one viewpoint, derive its diffusion masks, and score it.

    Pipeline: render the view, convert the render / normal / depth /
    similarity tensors to PIL images, build the diffusion masks with a flat
    texel shader, compose them into a quad mask, and compute the view heat
    (scaled by this view's punishment factor). Optionally saves all
    intermediate images under the given directories.

    Returns a large tuple: (view_heat, renderer, cameras, fragments,
    PIL images, raw tensors, mask images, mask tensors) — see the return
    statement for the exact order.
    """
    # render the view
    (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    ) = render_one_view(mesh,
        dist, elev, azim,
        image_size, faces_per_pixel,
        device
    )

    # convert batch element 0 of each render product to a PIL image
    init_image = init_images_tensor[0].cpu()
    init_image = init_image.permute(2, 0, 1)
    init_image = transforms.ToPILImage()(init_image).convert("RGB")

    normal_map = normal_maps_tensor[0].cpu()
    normal_map = normal_map.permute(2, 0, 1)
    normal_map = transforms.ToPILImage()(normal_map).convert("RGB")

    depth_map = depth_maps_tensor[0].cpu().numpy()
    depth_map = Image.fromarray(depth_map).convert("L")

    similarity_map = similarity_tensor[0, :, :, 0].cpu()
    similarity_map = transforms.ToPILImage()(similarity_map).convert("L")

    # masks are rendered with a flat texel shader so texture values pass
    # through unshaded
    flat_renderer = init_renderer(cameras,
        shader=init_flat_texel_shader(
            camera=cameras,
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
        (mesh, faces, verts_uvs),
        flat_renderer, exist_texture, similarity_texture_cache, selected_view_idx, device, image_size,
        smooth_mask=smooth_mask, view_threshold=view_threshold
    )
    # NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
    # it should match with `similarity_texture_cache`

    (
        old_mask_tensor,
        update_mask_tensor,
        new_mask_tensor,
        all_mask_tensor,
        quad_mask_tensor
    ) = compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device)

    view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
    # scale by this view's punishment factor (reduced each time it is chosen)
    view_heat *= view_punishments[selected_view_idx]

    # save intermediate results
    if save_intermediate:
        init_image.save(os.path.join(init_image_dir, "{}.png".format(view_idx)))
        normal_map.save(os.path.join(normal_map_dir, "{}.png".format(view_idx)))
        depth_map.save(os.path.join(depth_map_dir, "{}.png".format(view_idx)))
        similarity_map.save(os.path.join(similarity_map_dir, "{}.png".format(view_idx)))

        new_mask_image.save(os.path.join(mask_image_dir, "{}_new.png".format(view_idx)))
        update_mask_image.save(os.path.join(mask_image_dir, "{}_update.png".format(view_idx)))
        old_mask_image.save(os.path.join(mask_image_dir, "{}_old.png".format(view_idx)))
        exist_mask_image.save(os.path.join(mask_image_dir, "{}_exist.png".format(view_idx)))

        visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_heat, device)

    return (
        view_heat,
        renderer, cameras, fragments,
        init_image, normal_map, depth_map,
        init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
        old_mask_image, update_mask_image, new_mask_image,
        old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
    )
def backproject_from_image(mesh, faces, verts_uvs, cameras,
    reference_image, new_mask_image, update_mask_image,
    init_texture, exist_texture,
    image_size, uv_size, faces_per_pixel,
    device):
    """Back-project the generated ``reference_image`` into the UV texture,
    restricted to the union of the new and update masks.

    Updates ``init_texture`` (PIL image) at the affected texels and marks
    those texels as painted in ``exist_texture`` (modified in place).
    Returns (init_texture, project_mask_image, exist_texture).
    """
    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    # NOTE(review): the reshape sizes are read from the *pre-permute* tensor
    # (shape[-2] == K, shape[1] == Hs, shape[2] == Ws), so this only works
    # when the batch dimension N is 1 — confirm callers never batch.
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(pixel_uvs.shape[-2], pixel_uvs.shape[1], pixel_uvs.shape[2], 2)

    # the update mask has to be on top of the diffusion mask
    new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(device).unsqueeze(-1)
    update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(device).unsqueeze(-1)

    project_mask_image_tensor = torch.logical_or(update_mask_image_tensor, new_mask_image_tensor).float()
    project_mask_image = project_mask_image_tensor * 255.
    project_mask_image = Image.fromarray(project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))

    # rescale the binary mask to the (scaled) render resolution without interpolation
    project_mask_image_scaled = project_mask_image.resize(
        (image_size, image_size),
        Image.Resampling.NEAREST
    )
    project_mask_image_tensor_scaled = transforms.ToTensor()(project_mask_image_scaled).to(device)

    # keep only UVs of pixels inside the projection mask
    pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]

    # v is flipped because texture rows grow downward while v grows upward
    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = pixel_uvs.shape[0]
    # replicate the mask to (1, 4, H, W, 3) so it can select RGB values for
    # all 4 texel corners below
    project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:, None, :, :, None].repeat(1, 4, 1, 1, 3)

    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size))))
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
    texture_values_masked = texture_values.reshape(-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(-1, 3)

    # texture: scatter the masked image values into a copy of the current texture
    texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values_masked

    init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(np.uint8))

    # update texture cache (in place): these texels are now painted
    exist_texture[texture_locations_y, texture_locations_x] = 1

    return init_texture, project_mask_image, exist_texture