Spaces:
Paused
Paused
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| import trimesh | |
| import sys | |
| import os | |
| sys.path.append('vggsfm_code/') | |
| import shutil | |
| from datetime import datetime | |
| from vggsfm_code.hf_demo import demo_fn | |
| from omegaconf import DictConfig, OmegaConf | |
| from viz_utils.viz_fn import add_camera | |
| import glob | |
| # | |
| from scipy.spatial.transform import Rotation | |
| import PIL | |
| import gc | |
| # import spaces | |
| # @spaces.GPU | |
def _extract_video_frames(video_path, output_dir, max_frames, seconds_per_frame=1):
    """Sample frames from a video at a fixed time step and save them as PNGs.

    Frames are written to ``output_dir`` as ``000000.png``, ``000001.png``, ...

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.
        output_dir: Existing directory that receives the PNG files.
        max_frames: Upper bound on the number of frames written.
        seconds_per_frame: Sampling period in seconds of video time.

    Returns:
        The number of frames actually written.
    """
    capture = cv2.VideoCapture(video_path)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        # The reported FPS can be 0/NaN/inf when it is unknown; fall back to
        # keeping every frame so the modulo below can never divide by zero.
        if np.isfinite(fps) and fps > 0:
            frame_interval = max(int(fps * seconds_per_frame), 1)
        else:
            frame_interval = 1

        saved = 0
        frames_read = 0
        # Strict "<" so that at most ``max_frames`` images are produced.
        while saved < max_frames:
            got_frame, frame = capture.read()
            if not got_frame:
                break
            frames_read += 1
            if frames_read % frame_interval == 0:
                cv2.imwrite(f"{output_dir}/{saved:06}.png", frame)
                saved += 1
        return saved
    finally:
        # Always free the decoder/file handle, even if writing a frame fails.
        capture.release()


def vggsfm_demo(
    input_video,
    input_image,
    query_frame_num,
    max_query_pts=4096,
):
    """Run a VGGSfM reconstruction on uploaded images or a video.

    When both images and a video are supplied, only the images are used.
    Inputs are staged into a fresh ``input_images_<timestamp>/images``
    directory, ``demo_fn`` is run on it, and the resulting scene is exported
    as ``glbscene.glb`` inside that directory.

    Args:
        input_video: Path to a video file, a mapping whose
            ``["video"]["path"]`` entry holds that path, or ``None``.
        input_image: Sequence of image file paths, or ``None``.
        query_frame_num: Stored in the config as ``query_frame_num``.
        max_query_pts: Stored in the config as ``max_query_pts``.

    Returns:
        A ``(glb_path, message)`` tuple. ``glb_path`` is ``None`` and
        ``message`` explains the problem when the input is rejected;
        otherwise ``message`` is ``"Success"``.
    """
    max_input_image = 20
    min_input_image = 3

    if input_video is not None and not isinstance(input_video, str):
        input_video = input_video["video"]["path"]

    # Reject unusable requests before touching the GPU or the file system so
    # that invalid submissions do not leave empty working directories behind.
    if input_image is not None:
        if len(input_image) < min_input_image:
            return None, "Please input at least three frames"
    elif input_video is None:
        return None, "Input format incorrect"

    torch.cuda.empty_cache()

    cfg_file = "vggsfm_code/cfgs/demo.yaml"
    cfg = OmegaConf.load(cfg_file)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    if input_image is not None:
        # Images take priority over the video; keep only the first few.
        input_image = sorted(input_image)[:max_input_image]
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    else:
        video_frame_num = _extract_video_frames(
            input_video, target_dir_images, max_input_image
        )
        if video_frame_num < min_input_image:
            # Nothing useful was extracted; do not keep the staging directory.
            shutil.rmtree(target_dir, ignore_errors=True)
            return None, "Please input at least three frames"

    cfg.query_frame_num = query_frame_num
    cfg.max_query_pts = max_query_pts
    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    # Errors from demo_fn propagate on purpose; the app is launched with
    # show_error=True so the message reaches the user.
    predictions = demo_fn(cfg)

    glbscene = vggsfm_predictions_to_glb(predictions)
    glbfile = target_dir + "/glbscene.glb"
    glbscene.export(file_obj=glbfile)

    # Drop the (potentially GPU-resident) predictions before the next request.
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)
    return glbfile, "Success"
def vggsfm_predictions_to_glb(predictions):
    """Turn VGGSfM predictions into a trimesh scene with points and cameras.

    Adapted from https://github.com/naver/dust3r/blob/main/dust3r/viz.py

    Only the ``"points3D"``, ``"points3D_rgb"`` and ``"extrinsics_opencv"``
    entries are read; any other entries in ``predictions`` are ignored.

    Args:
        predictions: Mapping whose values expose ``.cpu().numpy()``.
            ``"extrinsics_opencv"`` must be assignable into an ``(N, 3, 4)``
            array (one camera-from-world matrix per frame).
            ``"points3D_rgb"`` is multiplied by 255 and cast to ``uint8``
            (presumably values in [0, 1] -- TODO confirm against demo_fn).

    Returns:
        A ``trimesh.Scene`` holding the coloured point cloud and one camera
        marker per frame, re-expressed relative to the first camera and
        translated so that its bounding-box centre sits at the origin.
    """
    points3D = predictions["points3D"].cpu().numpy()
    points3D_rgb = predictions["points3D_rgb"].cpu().numpy()
    points3D_rgb = (points3D_rgb * 255).astype(np.uint8)
    extrinsics_opencv = predictions["extrinsics_opencv"].cpu().numpy()

    glbscene = trimesh.Scene()
    glbscene.add_geometry(trimesh.PointCloud(points3D, colors=points3D_rgb))

    camera_edge_colors = [
        (255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255),
        (255, 204, 0), (0, 204, 204), (128, 255, 255), (255, 128, 255),
        (255, 255, 128), (0, 0, 0), (128, 128, 128),
    ]

    # Pad each 3x4 extrinsic to a homogeneous 4x4 matrix so it can be inverted.
    frame_num = len(extrinsics_opencv)
    extrinsics_opencv_4x4 = np.zeros((frame_num, 4, 4))
    extrinsics_opencv_4x4[:, :3, :4] = extrinsics_opencv
    extrinsics_opencv_4x4[:, 3, 3] = 1

    for idx, cam_from_world in enumerate(extrinsics_opencv_4x4):
        cam_to_world = np.linalg.inv(cam_from_world)
        # Cycle through the palette when there are more cameras than colours.
        cur_cam_color = camera_edge_colors[idx % len(camera_edge_colors)]
        add_camera(
            glbscene,
            cam_to_world,
            cur_cam_color,
            image=None,
            imsize=(1024, 1024),
            focal=None,
            screen_width=0.35,
        )

    # Axis flip (y and z negated) followed by a 180 degree turn about y,
    # combined with the first camera's pose to define the viewing frame.
    opengl_mat = np.array([[1, 0, 0, 0],
                           [0, -1, 0, 0],
                           [0, 0, -1, 0],
                           [0, 0, 0, 1]])
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()

    first_cam_to_world = np.linalg.inv(extrinsics_opencv_4x4[0])
    glbscene.apply_transform(np.linalg.inv(first_cam_to_world @ opengl_mat @ rot))

    # Calculate the bounding box center and apply the translation
    bounding_box = glbscene.bounds
    center = (bounding_box[0] + bounding_box[1]) / 2
    translation = np.eye(4)
    translation[:3, 3] = -center
    glbscene.apply_transform(translation)

    return glbscene
# Example videos bundled with the demo (paths are relative to the app's
# working directory).
british_museum_video = "vggsfm_code/examples/videos/british_museum_video.mp4"
cake_video = "vggsfm_code/examples/videos/cake_video.mp4"
bonsai_video = "vggsfm_code/examples/videos/bonsai_video.mp4"

# Image sets for the same scenes; each is whatever files the glob finds.
bonsai_images = glob.glob("vggsfm_code/examples/bonsai/images/*")
cake_images = glob.glob("vggsfm_code/examples/cake/images/*")
british_museum_images = glob.glob("vggsfm_code/examples/british_museum/images/*")
# Introductory HTML shown under the page title.
intro_html = """
<div style="text-align: left;">
<p>Welcome to <a href="https://vggsfm.github.io/" target="_blank">VGGSfM</a> demo!
This space demonstrates 3D reconstruction from input image frames. </p>
<p>To get started quickly, you can click on our <strong> examples (the bottom of the page) </strong>. If you want to reconstruct your own data, simply: </p>
<ul style="display: inline-block; text-align: left;">
<li>upload the images (.jpg, .png, etc.), or </li>
<li>upload a video (.mp4, .mov, etc.) </li>
</ul>
<p>The reconstruction should normally take <strong> up to 90 second </strong>. If both images and videos are uploaded, the demo will only reconstruct the uploaded images. By default, we extract <strong> 1 image frame per second from the input video </strong>. To prevent crashes on the Hugging Face space, we currently limit reconstruction to the first 20 image frames. </p>
<p>SfM methods are designed for <strong> rigid/static reconstruction </strong>. When dealing with dynamic/moving inputs, these methods may still work by focusing on the rigid parts of the scene. However, to ensure high-quality results, it is better to minimize the presence of moving objects in the input data. </p>
<p>If you meet any problem, feel free to create an issue in our <a href="https://github.com/facebookresearch/vggsfm" target="_blank">GitHub Repo</a> ⭐</p>
<p>(Please note that running reconstruction on Hugging Face space is slower than on a local machine.) </p>
</div>
"""

with gr.Blocks() as demo:
    gr.Markdown("# 🎨 VGGSfM: Visual Geometry Grounded Deep Structure From Motion")
    gr.Markdown(intro_html)

    with gr.Row():
        # Left column: user inputs and reconstruction settings.
        with gr.Column(scale=1):
            input_video = gr.Video(label="Input video", interactive=True)
            input_images = gr.File(
                file_count="multiple",
                label="Input Images",
                interactive=True,
            )
            num_query_images = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=5,
                label="Number of query images (key frames)",
                info="More query images usually lead to better reconstruction at lower speeds. If the viewpoint differences between your images are minimal, you can set this value to 1. ",
            )
            num_query_points = gr.Slider(
                minimum=512,
                maximum=4096,
                step=1,
                value=1024,
                label="Number of query points",
                info="More query points usually lead to denser reconstruction at lower speeds.",
            )

        # Right column: 3D viewer and status message.
        with gr.Column(scale=3):
            reconstruction_output = gr.Model3D(label="Reconstruction", height=520)
            log_output = gr.Textbox(label="Log")

    # Components passed to / filled by vggsfm_demo, in its argument order.
    demo_inputs = [input_video, input_images, num_query_images, num_query_points]
    demo_outputs = [reconstruction_output, log_output]

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        clear_btn = gr.ClearButton(demo_inputs + demo_outputs, scale=1)

    # Each row: video, image list, number of query images, number of query points.
    examples = [
        [british_museum_video, british_museum_images, 2, 4096],
        [bonsai_video, bonsai_images, 3, 2048],
        [cake_video, cake_images, 3, 2048],
    ]

    gr.Examples(
        examples=examples,
        inputs=demo_inputs,
        outputs=demo_outputs,
        fn=vggsfm_demo,
        cache_examples=True,
    )

    # Only one reconstruction runs at a time for the button handler.
    submit_btn.click(
        vggsfm_demo,
        demo_inputs,
        demo_outputs,
        concurrency_limit=1,
    )

demo.queue(max_size=20).launch(show_error=True, share=True)
| ######################################################################################################################## | |
| # else: | |
| # import glob | |
| # files = glob.glob(f'vggsfm_code/examples/cake/images/*', recursive=True) | |
| # vggsfm_demo(files, None, None) | |
| # demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True) | |