Junyi42 commited on
Commit
b6976aa
·
1 Parent(s): 54bbb7c
Files changed (2) hide show
  1. app.py +1 -1
  2. vis_st4rtrack.py +4 -4
app.py CHANGED
@@ -28,7 +28,7 @@ def check_ram_usage(threshold_percent=90):
28
  def main() -> None:
29
  # Load data once at startup using the function from vis_st4rtrack.py
30
  global global_data_cache
31
- global_data_cache = load_trajectory_data(use_float16=True, max_frames=120, traj_path="bonn_results", mask_folder="./train")
32
 
33
  app = fastapi.FastAPI()
34
  viser_manager = ViserProxyManager(app)
 
28
  def main() -> None:
29
  # Load data once at startup using the function from vis_st4rtrack.py
30
  global global_data_cache
31
+ global_data_cache = load_trajectory_data(use_float16=True, max_frames=120, traj_path="bonn_midanchor_ff", mask_folder="./train")
32
 
33
  app = fastapi.FastAPI()
34
  viser_manager = ViserProxyManager(app)
vis_st4rtrack.py CHANGED
@@ -28,7 +28,7 @@ def log_memory_usage(message=""):
28
  memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB
29
  print(f"Memory usage {message}: {memory_mb:.2f} MB")
30
 
31
- def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train'):
32
  """Load trajectory data from files.
33
 
34
  Args:
@@ -128,7 +128,7 @@ def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None,
128
  # repeat the conf_head1 to match the number of frames in the dimension 0
129
  conf_head1 = np.tile(conf_head1, (num_frames, 1))
130
  # Convert to float32 before calculating percentile to avoid overflow
131
- conf_thre = np.percentile(conf_head1.astype(np.float32), 1) # Default percentile
132
  conf_mask_head1 = conf_head1 > conf_thre
133
  data_cache['conf_mask_head1'] = conf_mask_head1
134
 
@@ -152,7 +152,7 @@ def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None,
152
  conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
153
  conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
154
  # set conf thre to be 1 percentile of the conf_head2, for each frame
155
- conf_thre = np.percentile(conf_head2.astype(np.float32), 1, axis=1)
156
  conf_mask_head2 = conf_head2 > conf_thre[:, None]
157
  data_cache['conf_mask_head2'] = conf_mask_head2
158
 
@@ -226,7 +226,7 @@ def visualize_st4rtrack(
226
  else:
227
  # Load data using the shared function
228
  print("No preloaded data available, loading from files...")
229
- data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder)
230
  traj_3d_head1 = data.get('traj_3d_head1')
231
  traj_3d_head2 = data.get('traj_3d_head2')
232
  conf_mask_head1 = data.get('conf_mask_head1')
 
28
  memory_mb = memory_info.rss / (1024 * 1024) # Convert to MB
29
  print(f"Memory usage {message}: {memory_mb:.2f} MB")
30
 
31
+ def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train', conf_thre_percentile=10):
32
  """Load trajectory data from files.
33
 
34
  Args:
 
128
  # repeat the conf_head1 to match the number of frames in the dimension 0
129
  conf_head1 = np.tile(conf_head1, (num_frames, 1))
130
  # Convert to float32 before calculating percentile to avoid overflow
131
+ conf_thre = np.percentile(conf_head1.astype(np.float32), conf_thre_percentile) # Default percentile
132
  conf_mask_head1 = conf_head1 > conf_thre
133
  data_cache['conf_mask_head1'] = conf_mask_head1
134
 
 
152
  conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
153
  conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
154
  # set conf thre to be 1 percentile of the conf_head2, for each frame
155
+ conf_thre = np.percentile(conf_head2.astype(np.float32), conf_thre_percentile, axis=1)
156
  conf_mask_head2 = conf_head2 > conf_thre[:, None]
157
  data_cache['conf_mask_head2'] = conf_mask_head2
158
 
 
226
  else:
227
  # Load data using the shared function
228
  print("No preloaded data available, loading from files...")
229
+ data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder, conf_thre_percentile)
230
  traj_3d_head1 = data.get('traj_3d_head1')
231
  traj_3d_head2 = data.get('traj_3d_head2')
232
  conf_mask_head1 = data.get('conf_mask_head1')