liguang0115 committed
Commit 390338e · 1 Parent(s): d5a5fa0

Refactor inference configuration and pipeline logic: remove unused parameters and improve the frame-selection process. Update inference settings in inference.yaml and streamline surfel model initialization in pipeline.py.
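For orientation, the snippet below is a minimal sketch of reading the settings this commit touches in configs/inference/inference.yaml (the diff follows). It assumes the YAML is parsed with OmegaConf, which this diff does not confirm, so treat the access pattern as illustrative rather than the project's API.

# Minimal sketch (assumption: OmegaConf; the repository's actual loader may differ).
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/inference/inference.yaml")

# Surfel optimization budget reduced by this commit: niter 1000 -> 400.
print(cfg.surfel.niter, cfg.surfel.lr)            # 400 0.01

# Frame-selection settings the streamlined pipeline still reads.
print(cfg.model.context_num_frames)               # 4
print(cfg.model.use_non_maximum_suppression)      # True
print(cfg.model.translation_distance_weight)      # 0.1

# Keys removed here (cache_dir, inference_mode, temporal_only, save_surfels,
# image_dir, meta_info_dir) no longer exist, so reading them now fails, e.g.:
# cfg.model.temporal_only  ->  raises a missing-key error under OmegaConf

With the machine-specific cache_dir, image_dir, and meta_info_dir entries gone, the shipped config keeps only Hub identifiers (liguang0115/vmem, liguang0115/cut3r) for weights.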

configs/inference/inference.yaml CHANGED
@@ -4,20 +4,15 @@ model:
   width: 576
   original_height: 288
   original_width: 512
-  cache_dir: "/homes/55/runjia/storage/svd_weights"
-  # pretrained_model_path: "stabilityai/stable-diffusion-2-1"
-  # pretrained_video_model_path: "stabilityai/stable-video-diffusion-img2vid"
+
 
   context_num_frames: 4
   target_num_frames: 4
   num_frames: 8
   vae_spatial_scale: 8
   latent_channels: 4
-  # num_ray_blocks: 2
   vae_scale_factor: 8
-  inference_mode: false
 
-  temporal_only: false
   use_non_maximum_suppression: true
   translation_distance_weight: 0.1
 
@@ -26,14 +21,7 @@ model:
   cfg_min: 1.2
   cfg: 2.0
   guider_types: 1
-
   samples_dir: "./visualization"
-  save_flag: false
-  use_wandb: false
-
-
-
-  # model_path: "/homes/55/runjia/storage/simview_weights/2025-04-30_12-08-55/checkpoint_230000.pth"
   model_path: "liguang0115/vmem"
 
 
@@ -45,7 +33,7 @@ surfel:
   merge_position_threshold: 0.2
   merge_normal_threshold: 0.6
   lr: 0.01
-  niter: 1000
+  niter: 400
   model_path: "liguang0115/cut3r"
   width: 512
   height: 288
@@ -54,14 +42,6 @@ inference:
   visualize: true
   visualize_pointcloud: false
   visualize_surfel: false
-  save_surfels: false
-  image_dir: "/homes/55/runjia/storage/realestate10k/video_data/test"
-  meta_info_dir: "/homes/55/runjia/storage/realestate10k/RealEstate10K/test"
-
-
-
-
-
 
 
 
 
modeling/pipeline.py CHANGED
@@ -4,27 +4,19 @@ from copy import deepcopy
 
 import math
 
-# import matplotlib.pyplot as plt
-# from mpl_toolkits.mplot3d.art3d import Poly3DCollection
-
 import PIL
-from PIL import Image, ImageOps
 import numpy as np
 from einops import repeat
-# from scipy.spatial import cKDTree
+
 
 import torch
 import torch.nn.functional as F
-from torch.amp import autocast
 import torchvision.transforms as tvf
 
 
-# from diffusers import AutoencoderKL, DiffusionPipeline
-# from diffusers.schedulers import DDIMScheduler
 from diffusers.utils import export_to_gif
 
 import sys
-# Add CUT3R to Python path for imports
 sys.path.append("./extern/CUT3R")
 from extern.CUT3R.surfel_inference import run_inference_from_pil
 from extern.CUT3R.add_ckpt_path import add_path_to_dust3r
@@ -91,32 +83,23 @@ class VMemPipeline:
         self.device = device
 
 
-        self.use_surfel = self.config.surfel.use_surfel
-        if self.use_surfel:
-            # Initialize CUT3R-based reconstructor
-            # Load and prepare the model
-            # download the model from huggingface
-
-            surfel_model_path = hf_hub_download(repo_id=self.config.surfel.model_path, filename="cut3r_512_dpt_4_64.pth")
-            print(f"Loading model from {surfel_model_path}...")
-            add_path_to_dust3r(surfel_model_path)
-            self.surfel_model = ARCroco3DStereo.from_pretrained(surfel_model_path).to(device)
-            self.surfel_model.eval()
-
-
-            # Import CUT3R scene alignment module
-            from extern.CUT3R.cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode
-            self.GlobalAlignerMode = GlobalAlignerMode
-            self.global_aligner = global_aligner
+        surfel_model_path = hf_hub_download(repo_id=self.config.surfel.model_path, filename="cut3r_512_dpt_4_64.pth")
+        print(f"Loading model from {surfel_model_path}...")
+        add_path_to_dust3r(surfel_model_path)
+        self.surfel_model = ARCroco3DStereo.from_pretrained(surfel_model_path).to(device)
+        self.surfel_model.eval()
+
+
+        # Import CUT3R scene alignment module
+        from extern.CUT3R.cloud_opt.dust3r_opt import global_aligner, GlobalAlignerMode
+        self.GlobalAlignerMode = GlobalAlignerMode
+        self.global_aligner = global_aligner
 
 
 
-        else:
-            self.surfel_model = None
+
 
 
-
-        self.temporal_only = self.config.model.temporal_only
         self.use_non_maximum_suppression = self.config.model.use_non_maximum_suppression
 
         self.context_num_frames = self.config.model.context_num_frames
@@ -537,33 +520,58 @@ class VMemPipeline:
             embeddings = [torch.from_numpy(self.encoder_embeddings[i]).to(self.device, self.dtype) for i in indices]
             intrinsics = [self.Ks[i] for i in indices]
             return c2ws, latents, embeddings, intrinsics, indices
-
-        if self.temporal_only:
-            # Select frames based on timesteps (temporal mode)
-            context_time_indices = [len(self.c2ws) - 1 - i for i in range(self.config.model.context_num_frames) if len(self.c2ws) - 1 - i >= 0]
-            context_data = prepare_context_data(context_time_indices)
-
-        elif not self.use_surfel:
-            # Select frames based on camera pose distance with NMS
-            average_c2w = average_camera_pose(target_c2ws)
-            distances = torch.stack([self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype), torch.from_numpy(np.array(c2w)).to(self.device, self.dtype), weight_translation=self.config.model.translation_distance_weight)
-                                     for c2w in self.c2ws])
+
+        if len(self.pil_frames) == 1:
+            context_time_indices = [0]
+        else:
+            # get the average camera pose
+            average_c2w = average_camera_pose(target_c2ws[-self.config.model.context_num_frames//4:])
+            transformed_average_c2w = self.get_transformed_c2ws(average_c2w)
+            target_K = np.mean(self.surfel_Ks, axis=0)
+            # Select frames using surfel-based relevance
+            retrieved_info = self.render_surfels_to_image(
+                self.surfels,
+                transformed_average_c2w,
+                [target_K*0.65] * 2,
+                principal_points=(int(self.config.surfel.width/2), int(self.config.surfel.height/2)),
+                image_width=int(self.config.surfel.width),
+                image_height=int(self.config.surfel.height)
+            )
+            _, frame_count = self.process_retrieved_spatial_information(retrieved_info)
+            if self.config.inference.visualize:
+                visualize_depth(retrieved_info["depth"],
+                                visualization_dir=self.visualize_dir,
+                                file_name=f"retrieved_depth_surfels.png",
+                                size=(self.width, self.height))
+
+
+            # Build candidate frames based on relevance count
+            candidates = []
+            for frame, count in frame_count:
+                candidates.extend([frame] * count)
+            indices_to_frame = {
+                i: frame for i, frame in enumerate(candidates)
+            }
+
+            # Sort candidates by distance to target view
+            distances = [self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype),
+                                                torch.from_numpy(self.c2ws[frame]).to(self.device, self.dtype),
+                                                weight_translation=self.config.model.translation_distance_weight).item()
+                         for frame in candidates]
 
-            # Sort frames by distance (closest to target first)
-            sorted_indices = torch.argsort(distances)
-            max_frames = min(self.config.model.context_num_frames, len(distances), len(self.latents))
+            sorted_indices = torch.argsort(torch.tensor(distances))
+            sorted_frames = [indices_to_frame[int(i.item())] for i in sorted_indices]
+            max_frames = min(self.config.model.context_num_frames, len(candidates), len(self.latents))
 
-            # Apply non-maximum suppression to select diverse frames
-            is_first_step = len(self.pil_frames) <= 1
+
             is_second_step = len(self.pil_frames) == 5
-            min_required_frames = 1 if is_first_step else max_frames
+
 
             # Adaptively determine initial threshold based on camera pose distribution
             if use_non_maximum_suppression is None:
                 use_non_maximum_suppression = self.use_non_maximum_suppression
 
             if use_non_maximum_suppression:
-
                 if is_second_step:
                     # Calculate pairwise distances between existing frames
                     pairwise_distances = []
@@ -581,32 +589,26 @@ class VMemPipeline:
                         pairwise_distances.sort()
                         percentile_idx = int(len(pairwise_distances) * 0.5) # 25th percentile
                         self.initial_threshold = pairwise_distances[percentile_idx]
-
-                        # Ensure threshold is within reasonable bounds
-                        # initial_threshold = max(0.00, min(0.001, initial_threshold))
                     else:
-                        self.initial_threshold = 0.001
-            elif is_first_step:
-                # Default threshold for first frame
-                self.initial_threshold = 1e8
+                        self.initial_threshold = 1
+
+
+
             else:
                 self.initial_threshold = 1e8
-
-
 
             selected_indices = []
+            current_threshold = self.initial_threshold
+
+            # Always start with the closest pose
+            selected_indices.append(sorted_frames[0])
+            if not use_non_maximum_suppression:
+                selected_indices.append(len(self.c2ws) - 1)
 
             # Try with increasingly relaxed thresholds until we get enough frames
-            current_threshold = self.initial_threshold
-            while len(selected_indices) < min_required_frames and current_threshold <= 1.0:
-                # Reset selection with new threshold
-                selected_indices = []
-
-                # Always start with the closest pose
-                selected_indices.append(sorted_indices[0])
-
+            while len(selected_indices) < max_frames and current_threshold >= 1e-5 and use_non_maximum_suppression:
                 # Try to add each subsequent pose in order of distance
-                for idx in sorted_indices[1:]:
+                for idx in sorted_frames[1:]:
                     if len(selected_indices) >= max_frames:
                         break
 
@@ -627,148 +629,22 @@ class VMemPipeline:
                        selected_indices.append(idx)
 
                 # If we still don't have enough frames, relax the threshold and try again
-                if len(selected_indices) < min_required_frames:
-                    current_threshold *= 1.2
+                if len(selected_indices) < max_frames:
+                    current_threshold /= 1.2
                 else:
                     break
 
             # If we still don't have enough frames, just take the top frames by distance
-            if len(selected_indices) < min_required_frames:
+            if len(selected_indices) < max_frames:
                 available_indices = []
-                for idx in sorted_indices:
+                for idx in sorted_frames:
                     if idx not in selected_indices:
                         available_indices.append(idx)
-                selected_indices.extend(available_indices[:min_required_frames-len(selected_indices)])
+                selected_indices.extend(available_indices[:max_frames-len(selected_indices)])
 
             # Convert to tensor and maintain original order (don't reverse)
-            context_time_indices = torch.tensor(selected_indices, device=distances.device)
-            context_data = prepare_context_data(context_time_indices)
-
-        else:
-            if len(self.pil_frames) == 1:
-                context_time_indices = [0]
-            else:
-                # get the average camera pose
-                average_c2w = average_camera_pose(target_c2ws[-self.config.model.context_num_frames//4:])
-                transformed_average_c2w = self.get_transformed_c2ws(average_c2w)
-                target_K = np.mean(self.surfel_Ks, axis=0)
-                # Select frames using surfel-based relevance
-                retrieved_info = self.render_surfels_to_image(
-                    self.surfels,
-                    transformed_average_c2w,
-                    [target_K*0.65] * 2,
-                    principal_points=(int(self.config.surfel.width/2), int(self.config.surfel.height/2)),
-                    image_width=int(self.config.surfel.width),
-                    image_height=int(self.config.surfel.height)
-                )
-                _, frame_count = self.process_retrieved_spatial_information(retrieved_info)
-                if self.config.inference.visualize:
-                    visualize_depth(retrieved_info["depth"],
-                                    visualization_dir=self.visualize_dir,
-                                    file_name=f"retrieved_depth_surfels.png",
-                                    size=(self.width, self.height))
-
-
-                # Build candidate frames based on relevance count
-                candidates = []
-                for frame, count in frame_count:
-                    candidates.extend([frame] * count)
-                indices_to_frame = {
-                    i: frame for i, frame in enumerate(candidates)
-                }
-
-                # Sort candidates by distance to target view
-                distances = [self.geodesic_distance(torch.from_numpy(average_c2w).to(self.device, self.dtype),
-                                                    torch.from_numpy(self.c2ws[frame]).to(self.device, self.dtype),
-                                                    weight_translation=self.config.model.translation_distance_weight).item()
-                             for frame in candidates]
-
-                sorted_indices = torch.argsort(torch.tensor(distances))
-                sorted_frames = [indices_to_frame[int(i.item())] for i in sorted_indices]
-                max_frames = min(self.config.model.context_num_frames, len(candidates), len(self.latents))
-
-
-                is_second_step = len(self.pil_frames) == 5
-
-
-                # Adaptively determine initial threshold based on camera pose distribution
-                if use_non_maximum_suppression is None:
-                    use_non_maximum_suppression = self.use_non_maximum_suppression
-
-                if use_non_maximum_suppression:
-                    if is_second_step:
-                        # Calculate pairwise distances between existing frames
-                        pairwise_distances = []
-                        for i in range(len(self.c2ws)):
-                            for j in range(i+1, len(self.c2ws)):
-                                sim = self.geodesic_distance(
-                                    torch.from_numpy(np.array(self.c2ws[i])).to(self.device, self.dtype),
-                                    torch.from_numpy(np.array(self.c2ws[j])).to(self.device, self.dtype),
-                                    weight_translation=self.config.model.translation_distance_weight
-                                )
-                                pairwise_distances.append(sim.item())
-
-                        if pairwise_distances:
-                            # Sort distances and take percentile as threshold
-                            pairwise_distances.sort()
-                            percentile_idx = int(len(pairwise_distances) * 0.5) # 25th percentile
-                            self.initial_threshold = pairwise_distances[percentile_idx]
-                        else:
-                            self.initial_threshold = 1
-
-
-
-                else:
-                    self.initial_threshold = 1e8
-
-                selected_indices = []
-                current_threshold = self.initial_threshold
-
-                # Always start with the closest pose
-                selected_indices.append(sorted_frames[0])
-                if not use_non_maximum_suppression:
-                    selected_indices.append(len(self.c2ws) - 1)
-
-                # Try with increasingly relaxed thresholds until we get enough frames
-                while len(selected_indices) < max_frames and current_threshold >= 1e-5 and use_non_maximum_suppression:
-                    # Try to add each subsequent pose in order of distance
-                    for idx in sorted_frames[1:]:
-                        if len(selected_indices) >= max_frames:
-                            break
-
-                        # Check if this candidate is sufficiently different from all selected frames
-                        is_too_similar = False
-                        for selected_idx in selected_indices:
-                            similarity = self.geodesic_distance(
-                                torch.from_numpy(np.array(self.c2ws[idx])).to(self.device, self.dtype),
-                                torch.from_numpy(np.array(self.c2ws[selected_idx])).to(self.device, self.dtype),
-                                weight_translation=self.config.model.translation_distance_weight
-                            )
-                            if similarity < current_threshold:
-                                is_too_similar = True
-                                break
-
-                        # Add to selected frames if not too similar to any existing selection
-                        if not is_too_similar:
-                            selected_indices.append(idx)
-
-                    # If we still don't have enough frames, relax the threshold and try again
-                    if len(selected_indices) < max_frames:
-                        current_threshold /= 1.2
-                    else:
-                        break
-
-                # If we still don't have enough frames, just take the top frames by distance
-                if len(selected_indices) < max_frames:
-                    available_indices = []
-                    for idx in sorted_frames:
-                        if idx not in selected_indices:
-                            available_indices.append(idx)
-                    selected_indices.extend(available_indices[:max_frames-len(selected_indices)])
-
-                # Convert to tensor and maintain original order (don't reverse)
-                context_time_indices = torch.from_numpy(np.array(selected_indices))
-                context_data = prepare_context_data(context_time_indices)
+            context_time_indices = torch.from_numpy(np.array(selected_indices))
+            context_data = prepare_context_data(context_time_indices)
 
         (context_c2ws, context_latents, context_encoder_embeddings, context_Ks, context_time_indices) = context_data
         print(f"context_time_indices: {context_time_indices}")
@@ -992,11 +868,7 @@ class VMemPipeline:
         # Flip Y and Z components of camera poses to match dataset convention
         c2ws_transformed = self.get_transformed_c2ws()
 
-        # Run inference to construct the scene
-        if self.global_step == 10:
-            visualize = True
-        else:
-            visualize = False
+
         scene = run_inference_from_pil(
             input_images,
             self.surfel_model,
@@ -1004,8 +876,7 @@ class VMemPipeline:
             depths=torch.from_numpy(np.array(self.surfel_depths)) if len(self.surfel_depths) > 0 else None,
             lr = lr,
             niter = niter,
-            # visualize=self.config.inference.visualize_pointcloud,
-            visualize=visualize,
+            visualize=self.config.inference.visualize_surfel,
             device=device,
         )
 
@@ -1043,12 +914,10 @@ class VMemPipeline:
         )
         confs = confs.squeeze(1)
 
-        # self.surfels = []
-        # self.surfel_to_timestep = {}
+
         start_idx = 0 if len(self.surfels) == 0 else len(pointcloud) - self.config.model.target_num_frames
         end_idx = len(pointcloud)
-        # for frame_idx in range(len(pointcloud)):
-        # Create surfels for the current frame
+
         for frame_idx in range(start_idx, end_idx):
            surfels = self.pointmap_to_surfels(
                pointmap=pointcloud[frame_idx],
@@ -1077,30 +946,6 @@ class VMemPipeline:
            for surfel_index in range(num_surfels):
                self.surfel_to_timestep[surfel_start_index + surfel_index] = [frame_idx]
 
-           # Save surfels if configured
-           if self.config.inference.save_surfels and len(self.surfels) > 0:
-               positions = np.array([s.position for s in surfels], dtype=np.float32)
-               normals = np.array([s.normal for s in surfels], dtype=np.float32)
-               radii = np.array([s.radius for s in surfels], dtype=np.float32)
-               colors = np.array([s.color for s in surfels], dtype=np.float32)
-
-               np.savez(f"{self.config.visualization_dir}/surfels_added.npz",
-                        positions=positions,
-                        normals=normals,
-                        radii=radii,
-                        colors=colors)
-
-               positions = np.array([s.position for s in self.surfels], dtype=np.float32)
-               normals = np.array([s.normal for s in self.surfels], dtype=np.float32)
-               radii = np.array([s.radius for s in self.surfels], dtype=np.float32)
-               colors = np.array([s.color for s in self.surfels], dtype=np.float32)
-
-               np.savez(f"{self.config.visualization_dir}/surfels_original.npz",
-                        positions=positions,
-                        normals=normals,
-                        radii=radii,
-                        colors=colors)
-
            self.surfels.extend(surfels)
 
            if self.config.inference.visualize_surfel:
@@ -1323,12 +1168,12 @@ class VMemPipeline:
            self.pil_frames[-1].save(f"{self.config.visualization_dir}/final_{len(self.pil_frames):07d}.png")
 
        # Update scene reconstruction if needed
-       if self.use_surfel and not self.temporal_only:
-           self.construct_and_store_scene(self.pil_frames,
-                                          time_indices=context_time_indices,
-                                          niter=self.config.surfel.niter,
-                                          lr=self.config.surfel.lr,
-                                          device=self.device)
+
+       self.construct_and_store_scene(self.pil_frames,
+                                      time_indices=context_time_indices,
+                                      niter=self.config.surfel.niter,
+                                      lr=self.config.surfel.lr,
+                                      device=self.device)
        self.global_step += 1
 
        if self.config.inference.visualize:
@@ -1386,9 +1231,9 @@ class VMemPipeline:
 
        # Handle surfels if using reconstructor
        self.global_step -= frames_to_remove
-       if self.use_surfel:
-           for _ in range(frames_to_remove):
-               self.surfel_depths.pop()
+
+       for _ in range(frames_to_remove):
+           self.surfel_depths.pop()
 
 
        # Find surfels that belong only to the removed timesteps
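To make the surviving frame-selection logic easier to read outside the diff, here is a self-contained sketch of the geodesic-distance non-maximum-suppression loop that this commit keeps: candidates are sorted by distance to the target pose, greedily accepted when they are at least a threshold apart from everything already selected, and the threshold is divided by 1.2 until enough context frames are found. The helper pose_distance and the plain NumPy poses are illustrative stand-ins for the pipeline's geodesic_distance and torch tensors, not the repository's API.

# Standalone sketch of the NMS-style context-frame selection kept by this commit.
# pose_distance is a stand-in for the pipeline's geodesic_distance; the real code
# works on torch tensors and surfel-derived candidate frames.
import numpy as np

def pose_distance(c2w_a, c2w_b, weight_translation=0.1):
    """Rotation angle plus weighted translation gap between two 4x4 c2w poses."""
    r = c2w_a[:3, :3] @ c2w_b[:3, :3].T
    angle = np.arccos(np.clip((np.trace(r) - 1.0) / 2.0, -1.0, 1.0))
    translation = np.linalg.norm(c2w_a[:3, 3] - c2w_b[:3, 3])
    return angle + weight_translation * translation

def select_context_frames(target_c2w, candidate_c2ws, max_frames, initial_threshold=1.0):
    """Greedy NMS over candidate poses, mirroring the loop kept in pipeline.py."""
    dists = [pose_distance(target_c2w, c) for c in candidate_c2ws]
    order = list(np.argsort(dists))

    threshold = initial_threshold
    selected = [order[0]]                     # always keep the closest pose
    while len(selected) < max_frames and threshold >= 1e-5:
        for idx in order[1:]:
            if len(selected) >= max_frames:
                break
            too_similar = any(
                pose_distance(candidate_c2ws[idx], candidate_c2ws[s]) < threshold
                for s in selected
            )
            if not too_similar:
                selected.append(idx)
        if len(selected) < max_frames:
            threshold /= 1.2                  # relax the separation requirement
        else:
            break

    # Fall back to the closest remaining frames if NMS was too strict.
    selected += [i for i in order if i not in selected][:max_frames - len(selected)]
    return selected

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    poses = []
    for _ in range(12):
        pose = np.eye(4)
        pose[:3, 3] = rng.normal(size=3)      # random translations, identity rotations
        poses.append(pose)
    print(select_context_frames(poses[0], poses[1:], max_frames=4, initial_threshold=0.5))

In the actual pipeline the initial threshold is seeded from a percentile of the pairwise pose distances once five frames exist, and when use_non_maximum_suppression is disabled the loop is skipped and the most recent frame is appended instead; the sketch leaves those branches out.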
 
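The "streamlined surfel model initialization" mentioned in the commit message boils down to always fetching the CUT3R checkpoint from the Hub in __init__, as shown in the pipeline.py hunk above. A minimal standalone sketch of just the download step follows; the repo id and filename are the ones visible in the diff, while everything around them is illustrative.

# Download the CUT3R checkpoint the same way the new __init__ path does.
# (liguang0115/cut3r and cut3r_512_dpt_4_64.pth come from the diff above;
# loading the file into ARCroco3DStereo is then done by the pipeline itself.)
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="liguang0115/cut3r",
                            filename="cut3r_512_dpt_4_64.pth")
print(ckpt_path)  # local cache path handed to the model loader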