Commit 390338e
Parent: d5a5fa0
Refactor inference configuration and pipeline logic; removed unused parameters and improved the frame-selection process. Updated inference settings in inference.yaml and streamlined surfel model initialization in pipeline.py.
Files changed:
- configs/inference/inference.yaml (+2 -22)
- modeling/pipeline.py (+88 -243)
configs/inference/inference.yaml (CHANGED)
@@ -4,20 +4,15 @@ model:
   width: 576
   original_height: 288
   original_width: 512
-
-  # pretrained_model_path: "stabilityai/stable-diffusion-2-1"
-  # pretrained_video_model_path: "stabilityai/stable-video-diffusion-img2vid"
+
 
   context_num_frames: 4
   target_num_frames: 4
   num_frames: 8
   vae_spatial_scale: 8
   latent_channels: 4
-  # num_ray_blocks: 2
   vae_scale_factor: 8
-  inference_mode: false
 
-  temporal_only: false
   use_non_maximum_suppression: true
   translation_distance_weight: 0.1
 
@@ -26,14 +21,7 @@ model:
   cfg_min: 1.2
   cfg: 2.0
   guider_types: 1
-
   samples_dir: "./visualization"
-  save_flag: false
-  use_wandb: false
-
-
-
-  # model_path: "/homes/55/runjia/storage/simview_weights/2025-04-30_12-08-55/checkpoint_230000.pth"
   model_path: "liguang0115/vmem"
 
 
@@ -45,7 +33,7 @@ surfel:
   merge_position_threshold: 0.2
   merge_normal_threshold: 0.6
   lr: 0.01
-  niter:
+  niter: 400
   model_path: "liguang0115/cut3r"
   width: 512
   height: 288
@@ -54,14 +42,6 @@ inference:
   visualize: true
   visualize_pointcloud: false
   visualize_surfel: false
-  save_surfels: false
-  image_dir: "/homes/55/runjia/storage/realestate10k/video_data/test"
-  meta_info_dir: "/homes/55/runjia/storage/realestate10k/RealEstate10K/test"
-
-
-
-
-
 
 
 
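Side note: pipeline.py reads this file through attribute-style access (self.config.model.context_num_frames, self.config.surfel.niter), which is consistent with an OmegaConf-style loader, though the actual loading code is not part of this diff. A minimal sketch of consuming the updated config under that assumption:

    # Minimal sketch; assumes an OmegaConf-style loader (not shown in this diff).
    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/inference/inference.yaml")

    # Values present after this commit:
    print(config.model.context_num_frames)  # 4
    print(config.model.num_frames)          # 8
    print(config.surfel.niter)              # 400
    print(config.surfel.model_path)         # "liguang0115/cut3r"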
modeling/pipeline.py (CHANGED)
@@ -4,27 +4,19 @@ from copy import deepcopy
 
 import math
 
-# import matplotlib.pyplot as plt
-# from mpl_toolkits.mplot3d.art3d import Poly3DCollection
-
 import PIL
-from PIL import Image, ImageOps
 import numpy as np
 from einops import repeat
-
+
 
 import torch
 import torch.nn.functional as F
-from torch.amp import autocast
 import torchvision.transforms as tvf
 
 
-# from diffusers import AutoencoderKL, DiffusionPipeline
-# from diffusers.schedulers import DDIMScheduler
 from diffusers.utils import export_to_gif
 
 import sys
-# Add CUT3R to Python path for imports
 sys.path.append("./extern/CUT3R")
 from extern.CUT3R.surfel_inference import run_inference_from_pil
 from extern.CUT3R.add_ckpt_path import add_path_to_dust3r
@@ -91,32 +83,23 @@ class VMemPipeline:
         self.device = device
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # Import CUT3R scene alignment module
-        from extern.CUT3R.cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode
-        self.GlobalAlignerMode = GlobalAlignerMode
-        self.global_aligner = global_aligner
+        surfel_model_path = hf_hub_download(repo_id=self.config.surfel.model_path, filename="cut3r_512_dpt_4_64.pth")
+        print(f"Loading model from {surfel_model_path}...")
+        add_path_to_dust3r(surfel_model_path)
+        self.surfel_model = ARCroco3DStereo.from_pretrained(surfel_model_path).to(device)
+        self.surfel_model.eval()
+
+
+        # Import CUT3R scene alignment module
+        from extern.CUT3R.cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode
+        self.GlobalAlignerMode = GlobalAlignerMode
+        self.global_aligner = global_aligner
 
 
 
-
-        self.surfel_model = None
+
 
 
-
-        self.temporal_only = self.config.model.temporal_only
         self.use_non_maximum_suppression = self.config.model.use_non_maximum_suppression
 
         self.context_num_frames = self.config.model.context_num_frames
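The constructor now initializes the surfel model eagerly instead of leaving self.surfel_model = None for deferred setup. hf_hub_download is the standard huggingface_hub helper: it fetches one file from a Hub repo and returns its local cache path, so repeated constructions reuse the cached checkpoint. A minimal sketch of just that resolution step (ARCroco3DStereo itself comes from the CUT3R extern and is not shown in this diff):

    from huggingface_hub import hf_hub_download

    # First call downloads the file; later calls return the cached local path.
    ckpt_path = hf_hub_download(repo_id="liguang0115/cut3r",
                                filename="cut3r_512_dpt_4_64.pth")
    print(ckpt_path)  # local path inside the Hugging Face cache directory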
@@ -537,33 +520,58 @@ class VMemPipeline:
             embeddings = [torch.from_numpy(self.encoder_embeddings[i]).to(self.device, self.dtype) for i in indices]
             intrinsics = [self.Ks[i] for i in indices]
             return c2ws, latents, embeddings, intrinsics, indices
-
-        if self.temporal_only:
-
-
-
-
-
-
-
-
-
+
+        if len(self.pil_frames) == 1:
+            context_time_indices = [0]
+        else:
+            # get the average camera pose
+            average_c2w = average_camera_pose(target_c2ws[-self.config.model.context_num_frames//4:])
+            transformed_average_c2w = self.get_transformed_c2ws(average_c2w)
+            target_K = np.mean(self.surfel_Ks, axis=0)
+            # Select frames using surfel-based relevance
+            retrieved_info = self.render_surfels_to_image(
+                self.surfels,
+                transformed_average_c2w,
+                [target_K*0.65] * 2,
+                principal_points=(int(self.config.surfel.width/2), int(self.config.surfel.height/2)),
+                image_width=int(self.config.surfel.width),
+                image_height=int(self.config.surfel.height)
+            )
+            _, frame_count = self.process_retrieved_spatial_information(retrieved_info)
+            if self.config.inference.visualize:
+                visualize_depth(retrieved_info["depth"],
+                                visualization_dir=self.visualize_dir,
+                                file_name=f"retrieved_depth_surfels.png",
+                                size=(self.width, self.height))
+
+
+            # Build candidate frames based on relevance count
+            candidates = []
+            for frame, count in frame_count:
+                candidates.extend([frame] * count)
+            indices_to_frame = {
+                i: frame for i, frame in enumerate(candidates)
+            }
+
+            # Sort candidates by distance to target view
+            distances = [self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype),
+                                                torch.from_numpy(self.c2ws[frame]).to(self.device, self.dtype),
+                                                weight_translation=self.config.model.translation_distance_weight).item()
+                         for frame in candidates]
 
-
-
-            max_frames = min(self.config.model.context_num_frames, len(
+            sorted_indices = torch.argsort(torch.tensor(distances))
+            sorted_frames = [indices_to_frame[int(i.item())] for i in sorted_indices]
+            max_frames = min(self.config.model.context_num_frames, len(candidates), len(self.latents))
 
-
-            is_first_step = len(self.pil_frames) <= 1
+
             is_second_step = len(self.pil_frames) == 5
-
+
 
             # Adaptively determine initial threshold based on camera pose distribution
             if use_non_maximum_suppression is None:
                 use_non_maximum_suppression = self.use_non_maximum_suppression
 
             if use_non_maximum_suppression:
-
                 if is_second_step:
                     # Calculate pairwise distances between existing frames
                     pairwise_distances = []
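The retained branch scores past frames by surfel relevance: frame_count pairs each frame index with how many rendered surfels it contributed, each frame is duplicated that many times in the candidate pool, and the expanded pool is sorted by pose distance to the averaged target pose. A self-contained sketch of that weighting-and-sorting step; pose_distance is a hypothetical stand-in for the pipeline's geodesic_distance:

    import numpy as np

    def rank_candidates(frame_count, poses, target_pose, pose_distance):
        # Duplicate each frame in proportion to its surfel relevance count,
        # so frequently hit frames dominate the candidate pool.
        candidates = []
        for frame, count in frame_count:
            candidates.extend([frame] * count)
        # Order the expanded pool by distance between each frame's camera
        # pose and the (averaged) target pose.
        distances = [pose_distance(target_pose, poses[f]) for f in candidates]
        order = np.argsort(distances)
        return [candidates[i] for i in order]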
@@ -581,32 +589,26 @@ class VMemPipeline:
                         pairwise_distances.sort()
                         percentile_idx = int(len(pairwise_distances) * 0.5)  # 25th percentile
                         self.initial_threshold = pairwise_distances[percentile_idx]
-
-                        # Ensure threshold is within reasonable bounds
-                        # initial_threshold = max(0.00, min(0.001, initial_threshold))
                     else:
-                        self.initial_threshold =
-
-
-
+                        self.initial_threshold = 1
+
+
+
             else:
                 self.initial_threshold = 1e8
-
-
 
             selected_indices = []
+            current_threshold = self.initial_threshold
+
+            # Always start with the closest pose
+            selected_indices.append(sorted_frames[0])
+            if not use_non_maximum_suppression:
+                selected_indices.append(len(self.c2ws) - 1)
 
             # Try with increasingly relaxed thresholds until we get enough frames
-            current_threshold
-            while len(selected_indices) < min_required_frames and current_threshold <= 1.0:
-                # Reset selection with new threshold
-                selected_indices = []
-
-                # Always start with the closest pose
-                selected_indices.append(sorted_indices[0])
-
+            while len(selected_indices) < max_frames and current_threshold >= 1e-5 and use_non_maximum_suppression:
                 # Try to add each subsequent pose in order of distance
-                for idx in sorted_indices[1:]:
+                for idx in sorted_frames[1:]:
                     if len(selected_indices) >= max_frames:
                         break
 
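geodesic_distance is used throughout the selection logic but is not defined in this diff. A common formulation for such a pose metric, and one consistent with the weight_translation parameter, combines the relative rotation angle with a weighted translation gap; the following is a sketch under that assumption, not the pipeline's actual implementation:

    import torch

    def geodesic_distance(c2w_a, c2w_b, weight_translation=0.1):
        # Rotation part: angle of the relative rotation, arccos((tr(R) - 1) / 2).
        R = c2w_a[:3, :3].T @ c2w_b[:3, :3]
        cos_angle = torch.clamp((torch.trace(R) - 1.0) / 2.0, -1.0, 1.0)
        angle = torch.arccos(cos_angle)
        # Translation part: Euclidean gap between camera centers, down-weighted.
        t_gap = torch.linalg.norm(c2w_a[:3, 3] - c2w_b[:3, 3])
        return angle + weight_translation * t_gap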
@@ -627,148 +629,22 @@ class VMemPipeline:
                         selected_indices.append(idx)
 
                 # If we still don't have enough frames, relax the threshold and try again
-                if len(selected_indices) <
-                current_threshold
+                if len(selected_indices) < max_frames:
+                    current_threshold /= 1.2
                 else:
                     break
 
             # If we still don't have enough frames, just take the top frames by distance
-            if len(selected_indices) <
+            if len(selected_indices) < max_frames:
                 available_indices = []
-                for idx in
+                for idx in sorted_frames:
                     if idx not in selected_indices:
                         available_indices.append(idx)
-                selected_indices.extend(available_indices[:
+                selected_indices.extend(available_indices[:max_frames-len(selected_indices)])
 
             # Convert to tensor and maintain original order (don't reverse)
-            context_time_indices = torch.
-
-
-        else:
-            if len(self.pil_frames) == 1:
-                context_time_indices = [0]
-            else:
-                # get the average camera pose
-                average_c2w = average_camera_pose(target_c2ws[-self.config.model.context_num_frames//4:])
-                transformed_average_c2w = self.get_transformed_c2ws(average_c2w)
-                target_K = np.mean(self.surfel_Ks, axis=0)
-                # Select frames using surfel-based relevance
-                retrieved_info = self.render_surfels_to_image(
-                    self.surfels,
-                    transformed_average_c2w,
-                    [target_K*0.65] * 2,
-                    principal_points=(int(self.config.surfel.width/2), int(self.config.surfel.height/2)),
-                    image_width=int(self.config.surfel.width),
-                    image_height=int(self.config.surfel.height)
-                )
-                _, frame_count = self.process_retrieved_spatial_information(retrieved_info)
-                if self.config.inference.visualize:
-                    visualize_depth(retrieved_info["depth"],
-                                    visualization_dir=self.visualize_dir,
-                                    file_name=f"retrieved_depth_surfels.png",
-                                    size=(self.width, self.height))
-
-
-                # Build candidate frames based on relevance count
-                candidates = []
-                for frame, count in frame_count:
-                    candidates.extend([frame] * count)
-                indices_to_frame = {
-                    i: frame for i, frame in enumerate(candidates)
-                }
-
-                # Sort candidates by distance to target view
-                distances = [self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype),
-                                                    torch.from_numpy(self.c2ws[frame]).to(self.device, self.dtype),
-                                                    weight_translation=self.config.model.translation_distance_weight).item()
-                             for frame in candidates]
-
-                sorted_indices = torch.argsort(torch.tensor(distances))
-                sorted_frames = [indices_to_frame[int(i.item())] for i in sorted_indices]
-                max_frames = min(self.config.model.context_num_frames, len(candidates), len(self.latents))
-
-
-                is_second_step = len(self.pil_frames) == 5
-
-
-                # Adaptively determine initial threshold based on camera pose distribution
-                if use_non_maximum_suppression is None:
-                    use_non_maximum_suppression = self.use_non_maximum_suppression
-
-                if use_non_maximum_suppression:
-                    if is_second_step:
-                        # Calculate pairwise distances between existing frames
-                        pairwise_distances = []
-                        for i in range(len(self.c2ws)):
-                            for j in range(i+1, len(self.c2ws)):
-                                sim = self.geodesic_distance(
-                                    torch.from_numpy(np.array(self.c2ws[i])).to(self.device, self.dtype),
-                                    torch.from_numpy(np.array(self.c2ws[j])).to(self.device, self.dtype),
-                                    weight_translation=self.config.model.translation_distance_weight
-                                )
-                                pairwise_distances.append(sim.item())
-
-                        if pairwise_distances:
-                            # Sort distances and take percentile as threshold
-                            pairwise_distances.sort()
-                            percentile_idx = int(len(pairwise_distances) * 0.5)  # 25th percentile
-                            self.initial_threshold = pairwise_distances[percentile_idx]
-                        else:
-                            self.initial_threshold = 1
-
-
-
-                else:
-                    self.initial_threshold = 1e8
-
-                selected_indices = []
-                current_threshold = self.initial_threshold
-
-                # Always start with the closest pose
-                selected_indices.append(sorted_frames[0])
-                if not use_non_maximum_suppression:
-                    selected_indices.append(len(self.c2ws) - 1)
-
-                # Try with increasingly relaxed thresholds until we get enough frames
-                while len(selected_indices) < max_frames and current_threshold >= 1e-5 and use_non_maximum_suppression:
-                    # Try to add each subsequent pose in order of distance
-                    for idx in sorted_frames[1:]:
-                        if len(selected_indices) >= max_frames:
-                            break
-
-                        # Check if this candidate is sufficiently different from all selected frames
-                        is_too_similar = False
-                        for selected_idx in selected_indices:
-                            similarity = self.geodesic_distance(
-                                torch.from_numpy(np.array(self.c2ws[idx])).to(self.device, self.dtype),
-                                torch.from_numpy(np.array(self.c2ws[selected_idx])).to(self.device, self.dtype),
-                                weight_translation=self.config.model.translation_distance_weight
-                            )
-                            if similarity < current_threshold:
-                                is_too_similar = True
-                                break
-
-                        # Add to selected frames if not too similar to any existing selection
-                        if not is_too_similar:
-                            selected_indices.append(idx)
-
-                    # If we still don't have enough frames, relax the threshold and try again
-                    if len(selected_indices) < max_frames:
-                        current_threshold /= 1.2
-                    else:
-                        break
-
-                # If we still don't have enough frames, just take the top frames by distance
-                if len(selected_indices) < max_frames:
-                    available_indices = []
-                    for idx in sorted_frames:
-                        if idx not in selected_indices:
-                            available_indices.append(idx)
-                    selected_indices.extend(available_indices[:max_frames-len(selected_indices)])
-
-                # Convert to tensor and maintain original order (don't reverse)
-                context_time_indices = torch.from_numpy(np.array(selected_indices))
-            context_data = prepare_context_data(context_time_indices)
+            context_time_indices = torch.from_numpy(np.array(selected_indices))
+        context_data = prepare_context_data(context_time_indices)
 
         (context_c2ws, context_latents, context_encoder_embeddings, context_Ks, context_time_indices) = context_data
         print(f"context_time_indices: {context_time_indices}")
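Taken together, the selection that survives this refactor is a greedy non-maximum suppression over camera poses: accept a candidate only if it is at least current_threshold away from everything already selected, divide the threshold by 1.2 whenever a pass comes up short, and finally top up with the nearest remaining frames. A condensed, self-contained sketch of that loop (illustrative only; the real code also special-cases the first steps and the NMS-off path):

    def select_frames(sorted_frames, distance, max_frames, threshold):
        # sorted_frames: candidate frame indices, nearest to target first.
        # distance: callable returning the pose distance between two frames.
        selected = [sorted_frames[0]]  # always keep the closest pose
        while len(selected) < max_frames and threshold >= 1e-5:
            for idx in sorted_frames[1:]:
                if len(selected) >= max_frames:
                    break
                # Keep idx only if it is far enough from every selected frame.
                if idx not in selected and all(distance(idx, s) >= threshold for s in selected):
                    selected.append(idx)
            if len(selected) < max_frames:
                threshold /= 1.2  # relax and retry
            else:
                break
        # Fallback: top up with the nearest remaining candidates.
        for idx in sorted_frames:
            if len(selected) >= max_frames:
                break
            if idx not in selected:
                selected.append(idx)
        return selected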
@@ -992,11 +868,7 @@ class VMemPipeline:
         # Flip Y and Z components of camera poses to match dataset convention
         c2ws_transformed = self.get_transformed_c2ws()
 
-
-        if self.global_step == 10:
-            visualize = True
-        else:
-            visualize = False
+
         scene = run_inference_from_pil(
             input_images,
             self.surfel_model,
@@ -1004,8 +876,7 @@ class VMemPipeline:
             depths=torch.from_numpy(np.array(self.surfel_depths)) if len(self.surfel_depths) > 0 else None,
             lr = lr,
             niter = niter,
-
-            visualize=visualize,
+            visualize=self.config.inference.visualize_surfel,
             device=device,
         )
 
@@ -1043,12 +914,10 @@ class VMemPipeline:
         )
         confs = confs.squeeze(1)
 
-
-        # self.surfel_to_timestep = {}
+
         start_idx = 0 if len(self.surfels) == 0 else len(pointcloud) - self.config.model.target_num_frames
         end_idx = len(pointcloud)
-
-        # Create surfels for the current frame
+
         for frame_idx in range(start_idx, end_idx):
             surfels = self.pointmap_to_surfels(
                 pointmap=pointcloud[frame_idx],
@@ -1077,30 +946,6 @@ class VMemPipeline:
             for surfel_index in range(num_surfels):
                 self.surfel_to_timestep[surfel_start_index + surfel_index] = [frame_idx]
 
-            # Save surfels if configured
-            if self.config.inference.save_surfels and len(self.surfels) > 0:
-                positions = np.array([s.position for s in surfels], dtype=np.float32)
-                normals = np.array([s.normal for s in surfels], dtype=np.float32)
-                radii = np.array([s.radius for s in surfels], dtype=np.float32)
-                colors = np.array([s.color for s in surfels], dtype=np.float32)
-
-                np.savez(f"{self.config.visualization_dir}/surfels_added.npz",
-                         positions=positions,
-                         normals=normals,
-                         radii=radii,
-                         colors=colors)
-
-                positions = np.array([s.position for s in self.surfels], dtype=np.float32)
-                normals = np.array([s.normal for s in self.surfels], dtype=np.float32)
-                radii = np.array([s.radius for s in self.surfels], dtype=np.float32)
-                colors = np.array([s.color for s in self.surfels], dtype=np.float32)
-
-                np.savez(f"{self.config.visualization_dir}/surfels_original.npz",
-                         positions=positions,
-                         normals=normals,
-                         radii=radii,
-                         colors=colors)
-
             self.surfels.extend(surfels)
 
             if self.config.inference.visualize_surfel:
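This hunk deletes the save_surfels debug dump along with its config flag. The removed np.savez calls wrote flat arrays keyed positions / normals / radii / colors; a short sketch for reading an archive produced by the old code back in, should one still be lying around:

    import numpy as np

    data = np.load("surfels_original.npz")  # archive written by the removed code
    positions = data["positions"]  # (N, 3) float32 surfel centers
    normals   = data["normals"]    # (N, 3) float32 normals
    radii     = data["radii"]      # (N,)   float32 disc radii
    colors    = data["colors"]     # (N, 3) float32 colors
    print(positions.shape, radii.shape)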
@@ -1323,12 +1168,12 @@ class VMemPipeline:
         self.pil_frames[-1].save(f"{self.config.visualization_dir}/final_{len(self.pil_frames):07d}.png")
 
         # Update scene reconstruction if needed
-
-
-
-
-
-
+
+        self.construct_and_store_scene(self.pil_frames,
+                                       time_indices=context_time_indices,
+                                       niter=self.config.surfel.niter,
+                                       lr=self.config.surfel.lr,
+                                       device=self.device)
         self.global_step += 1
 
         if self.config.inference.visualize:
@@ -1386,9 +1231,9 @@ class VMemPipeline:
 
         # Handle surfels if using reconstructor
         self.global_step -= frames_to_remove
-
-
-
+
+        for _ in range(frames_to_remove):
+            self.surfel_depths.pop()
 
 
         # Find surfels that belong only to the removed timesteps