- Dockerfile +22 -0
- README.md +30 -8
- app.py +105 -0
- requirements.txt +15 -0
- vis_st4rtrack.py +781 -0
- viser_proxy_manager.py +223 -0
Dockerfile
ADDED
@@ -0,0 +1,22 @@
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies for OpenCV and build tools
RUN apt-get update && apt-get install -y \
    libgl1-mesa-glx \
    libglib2.0-0 \
    build-essential \
    git \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Make port 7860 available to the world outside the container
EXPOSE 7860

# Command to run when the container starts
CMD ["python", "app.py"]
README.md
CHANGED
@@ -1,12 +1,34 @@
 ---
-title: Viser
-emoji:
-colorFrom:
-colorTo:
-sdk:
-
-app_file: app.py
+title: Viser Gradio Embed
+emoji: 🚀
+colorFrom: blue
+colorTo: pink
+sdk: docker
+app_port: 7860
 pinned: false
 ---

-
+# Viser + Gradio
+
+Demo for integrating [viser](https://github.com/nerfstudio-project/viser) 3D
+visualizations into a [Gradio](https://www.gradio.app/) application.
+
+- Uses Gradio's session management to create isolated 3D visualization contexts.
+- Exposes both Gradio and Viser over the same port.
+
+## Deploying on HuggingFace Spaces
+
+**[ [Live example](https://huggingface.co/spaces/brentyi/viser-gradio-embed) ]**
+
+This repository should work out-of-the-box with HF Spaces via Docker.
+
+- Unlike a vanilla Gradio Space, this is unfortunately not supported by [ZeroGPU](https://huggingface.co/docs/hub/en/spaces-zerogpu).
+
+## Local Demo
+
+```bash
+pip install -r requirements.txt
+python app.py
+```
+
+https://github.com/user-attachments/assets/b94a117a-b9e5-4854-805a-8666941c7816
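
The two README bullets above boil down to a small pattern: one FastAPI app hosts both the Gradio UI and a per-session viser server, keyed by Gradio's `session_hash`. Below is a minimal sketch of that pattern, stripped of the data loading, memory checks, and protocol/host detection that the full `app.py` (next file) adds; it uses a relative iframe `src`, whereas `app.py` builds an absolute URL.

```python
import fastapi
import gradio as gr
import uvicorn

from viser_proxy_manager import ViserProxyManager

app = fastapi.FastAPI()
manager = ViserProxyManager(app)  # registers /viser/{session}/... proxy routes

with gr.Blocks() as demo:
    iframe_html = gr.HTML("")

    @demo.load(outputs=[iframe_html])
    def start(request: gr.Request):
        # One viser server per Gradio session, addressed by the session hash.
        manager.start_server(request.session_hash)
        return (
            f'<iframe src="/viser/{request.session_hash}/" '
            'width="100%" height="500px"></iframe>'
        )

    @demo.unload
    def stop(request: gr.Request):
        manager.stop_server(request.session_hash)

gr.mount_gradio_app(app, demo, "/")  # Gradio and viser share the same port
uvicorn.run(app, host="0.0.0.0", port=7860)
```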
app.py
ADDED
@@ -0,0 +1,105 @@
import random
import threading
import psutil
import fastapi
import gradio as gr
import uvicorn

from viser_proxy_manager import ViserProxyManager
from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage

# Global cache for loaded data
global_data_cache = None

def check_ram_usage(threshold_percent=90):
    """Check whether RAM usage is below the given threshold.

    Args:
        threshold_percent: Maximum RAM usage percentage allowed

    Returns:
        bool: True if RAM usage is below threshold, False otherwise
    """
    ram_percent = psutil.virtual_memory().percent
    print(f"Current RAM usage: {ram_percent}%")
    return ram_percent < threshold_percent


def main() -> None:
    # Load data once at startup using the function from vis_st4rtrack.py
    global global_data_cache
    global_data_cache = load_trajectory_data(
        use_float16=True, max_frames=120, traj_path="bonn_results", mask_folder="./train"
    )

    app = fastapi.FastAPI()
    viser_manager = ViserProxyManager(app)

    # Create a Gradio interface with title, iframe, and buttons
    with gr.Blocks(title="Viser Viewer") as demo:
        # Add the iframe with a border
        iframe_html = gr.HTML("")
        status_text = gr.Markdown("")  # Add status text component

        @demo.load(outputs=[iframe_html, status_text])
        def start_server(request: gr.Request):
            assert request.session_hash is not None

            # Check RAM usage before starting visualization
            if not check_ram_usage(threshold_percent=100):
                return """
                <div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
                    <h2>⚠️ Server is currently under high load</h2>
                    <p>Please try again later when resources are available.</p>
                </div>
                """, "**System Status:** High memory usage detected. Visualization not loaded to prevent server overload."

            viser_manager.start_server(request.session_hash)

            # Use the request's base URL if available
            host = request.headers["host"]

            # Determine protocol (use HTTPS for HuggingFace Spaces or other secure environments)
            protocol = (
                "https"
                if request.headers.get("x-forwarded-proto") == "https"
                else "http"
            )

            # Add visualization in a separate thread
            server = viser_manager.get_server(request.session_hash)
            threading.Thread(
                target=visualize_st4rtrack,
                kwargs={
                    "server": server,
                    "use_float16": True,
                    "preloaded_data": global_data_cache,  # Pass the preloaded data
                    "color_code": "jet",
                    "blue_rgb": (0.0, 0.149, 0.463),  # #002676
                    "red_rgb": (0.769, 0.510, 0.055),  # #FDB515
                    "blend_ratio": 0.7,
                },
                daemon=True,
            ).start()

            return f"""
            <iframe
                src="{protocol}://{host}/viser/{request.session_hash}/"
                width="100%"
                height="500px"
                frameborder="0"
                style="display: block;"
                allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                loading="lazy"
            ></iframe>
            """, "**System Status:** Visualization loaded successfully."

        @demo.unload
        def stop(request: gr.Request):
            assert request.session_hash is not None
            viser_manager.stop_server(request.session_hash)

    gr.mount_gradio_app(app, demo, "/")
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,15 @@
# git+https://github.com/nerfstudio-project/viser.git
viser>=0.2.23
gradio==5.23.1
fastapi==0.115.11
uvicorn==0.34.0
httpx==0.27.2
websockets==15.0.1
tyro==0.4.1
numpy>=1.20.0
tqdm>=4.62.0
opencv-python>=4.5.0
imageio>=2.25.0
matplotlib>=3.5.0
pyliblzfse>=0.1.0
psutil>=5.9.0
vis_st4rtrack.py
ADDED
@@ -0,0 +1,781 @@
"""St4RTrack trajectory visualizer.

Loads St4RTrack outputs (per-frame tracking and reconstruction point maps plus
confidences) and streams them to a viser scene with playback, masking, and
trajectory controls.
"""

import time
from pathlib import Path

import numpy as onp
import tyro
import cv2
from tqdm.auto import tqdm

import viser
import viser.extras
import viser.transforms as tf

from glob import glob
import numpy as np
import imageio.v3 as iio
import matplotlib.pyplot as plt
import psutil

def log_memory_usage(message=""):
    """Log current memory usage with an optional message."""
    process = psutil.Process()
    memory_info = process.memory_info()
    memory_mb = memory_info.rss / (1024 * 1024)  # Convert to MB
    print(f"Memory usage {message}: {memory_mb:.2f} MB")

def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train'):
    """Load trajectory data from files.

    Args:
        traj_path: Path to the directory containing trajectory data
        use_float16: Whether to convert data to float16 to save memory
        max_frames: Maximum number of frames to load (None for all)
        mask_folder: Path to the directory containing mask images

    Returns:
        A dictionary containing loaded data
    """
    log_memory_usage("before loading data")

    data_cache = {
        'traj_3d_head1': None,
        'traj_3d_head2': None,
        'conf_mask_head1': None,
        'conf_mask_head2': None,
        'masks': None,
        'raw_video': None,
        'loaded': False
    }

    # Load masks
    masks_paths = sorted(glob(mask_folder + '/*.jpg'))
    masks = None

    if masks_paths:
        masks = [iio.imread(p) for p in masks_paths]
        masks = np.stack(masks, axis=0)
        # Convert masks to binary (0 or 1)
        masks = (masks < 1).astype(np.float32)
        masks = masks.sum(axis=-1) > 2  # Combine channels: True where all three channels were below the threshold
        print(f"Original masks shape: {masks.shape}")
    else:
        print("No masks found. Will create default masks when needed.")

    data_cache['masks'] = masks

    if Path(traj_path).is_dir():
        # Find all trajectory files
        traj_3d_paths_head1 = sorted(glob(traj_path + '/pts3d1_p*.npy'),
                                     key=lambda x: int(x.split('_p')[-1].split('.')[0]))
        conf_paths_head1 = sorted(glob(traj_path + '/conf1_p*.npy'),
                                  key=lambda x: int(x.split('_p')[-1].split('.')[0]))

        traj_3d_paths_head2 = sorted(glob(traj_path + '/pts3d2_p*.npy'),
                                     key=lambda x: int(x.split('_p')[-1].split('.')[0]))
        conf_paths_head2 = sorted(glob(traj_path + '/conf2_p*.npy'),
                                  key=lambda x: int(x.split('_p')[-1].split('.')[0]))

        # Limit number of frames if specified
        if max_frames is not None:
            traj_3d_paths_head1 = traj_3d_paths_head1[:max_frames]
            conf_paths_head1 = conf_paths_head1[:max_frames] if conf_paths_head1 else []
            traj_3d_paths_head2 = traj_3d_paths_head2[:max_frames]
            conf_paths_head2 = conf_paths_head2[:max_frames] if conf_paths_head2 else []

        # Process head1
        if traj_3d_paths_head1:
            if use_float16:
                traj_3d_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head1], axis=0)
            else:
                traj_3d_head1 = onp.stack([onp.load(p) for p in traj_3d_paths_head1], axis=0)

            log_memory_usage("after loading head1 data")

            h, w, _ = traj_3d_head1.shape[1:]
            num_frames = traj_3d_head1.shape[0]

            # If masks is None, create default masks (all ones)
            if masks is None:
                masks = np.ones((num_frames, h, w), dtype=bool)
                print(f"Created default masks with shape: {masks.shape}")
                data_cache['masks'] = masks
            else:
                # Resize masks to match trajectory dimensions using nearest neighbor interpolation
                masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
                for i in range(masks.shape[0]):
                    masks_resized[i] = cv2.resize(
                        masks[i].astype(np.uint8),
                        (w, h),
                        interpolation=cv2.INTER_NEAREST
                    ).astype(bool)

                print(f"Resized masks shape: {masks_resized.shape}")
                data_cache['masks'] = masks_resized

            # Reshape trajectory data
            traj_3d_head1 = traj_3d_head1.reshape(traj_3d_head1.shape[0], -1, 6)
            data_cache['traj_3d_head1'] = traj_3d_head1

            if conf_paths_head1:
                conf_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head1], axis=0)
                conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
                conf_head1 = conf_head1.mean(axis=0)
                # Repeat conf_head1 to match the number of frames along dimension 0
                conf_head1 = np.tile(conf_head1, (num_frames, 1))
                # Convert to float32 before calculating percentile to avoid overflow
                conf_thre = np.percentile(conf_head1.astype(np.float32), 1)  # Default percentile
                conf_mask_head1 = conf_head1 > conf_thre
                data_cache['conf_mask_head1'] = conf_mask_head1

        # Process head2
        if traj_3d_paths_head2:
            if use_float16:
                traj_3d_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head2], axis=0)
            else:
                traj_3d_head2 = onp.stack([onp.load(p) for p in traj_3d_paths_head2], axis=0)

            log_memory_usage("after loading head2 data")

            # Store raw video data
            raw_video = traj_3d_head2[:, :, :, 3:6]  # [num_frames, h, w, 3]
            data_cache['raw_video'] = raw_video

            traj_3d_head2 = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)
            data_cache['traj_3d_head2'] = traj_3d_head2

            if conf_paths_head2:
                conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
                conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
                # Set the confidence threshold to the 1st percentile of conf_head2, per frame
                conf_thre = np.percentile(conf_head2.astype(np.float32), 1, axis=1)
                conf_mask_head2 = conf_head2 > conf_thre[:, None]
                data_cache['conf_mask_head2'] = conf_mask_head2

    data_cache['loaded'] = True
    log_memory_usage("after loading all data")
    return data_cache

def visualize_st4rtrack(
    traj_path: str = "results",
    up_dir: str = "-z",  # should be +z or -z
    max_frames: int = 100,
    share: bool = False,
    point_size: float = 0.005,
    downsample_factor: int = 3,
    num_traj_points: int = 100,
    conf_thre_percentile: float = 1,
    traj_end_frame: int = 100,
    traj_start_frame: int = 0,
    traj_line_width: float = 3.0,
    fixed_length_traj: int = 20,
    server: viser.ViserServer = None,
    use_float16: bool = True,
    preloaded_data: dict = None,  # Optional preloaded data cache (skips file loading)
    color_code: str = "jet",
    # Hex colors: #002676 for blue and #FDB515 for red/gold
    blue_rgb: tuple[float, float, float] = (0.0, 0.149, 0.463),  # #002676
    red_rgb: tuple[float, float, float] = (0.769, 0.510, 0.055),  # #FDB515
    blend_ratio: float = 0.7,
    mask_folder: str = None,
    mid_anchor: bool = False,
    video_width: int = 320,  # Video display width
    video_height: int = 180,  # Video display height
    camera_position: tuple[float, float, float] = (1e-3, 1.5, -0.2),
) -> None:
    log_memory_usage("at start of visualization")

    if server is None:
        server = viser.ViserServer()
    if share:
        server.request_share_url()

    @server.on_client_connect
    def _(client: viser.ClientHandle) -> None:
        client.camera.position = camera_position
        client.camera.look_at = (0, 0, 0)

    # Configure the GUI panel size and layout
    server.gui.configure_theme(
        control_layout="collapsible",
        control_width="small",
        dark_mode=False,
        show_logo=False,
        show_share_button=True
    )

    # Add video preview to the GUI panel - placed at the top
    video_preview = server.gui.add_image(
        np.zeros((video_height, video_width, 3), dtype=np.uint8),  # Initial blank image
        format="jpeg"
    )

    # Use preloaded data if available
    if preloaded_data and preloaded_data.get('loaded', False):
        traj_3d_head1 = preloaded_data.get('traj_3d_head1')
        traj_3d_head2 = preloaded_data.get('traj_3d_head2')
        conf_mask_head1 = preloaded_data.get('conf_mask_head1')
        conf_mask_head2 = preloaded_data.get('conf_mask_head2')
        masks = preloaded_data.get('masks')
        raw_video = preloaded_data.get('raw_video')
        print("Using preloaded data!")
    else:
        # Load data using the shared function
        print("No preloaded data available, loading from files...")
        data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder)
        traj_3d_head1 = data.get('traj_3d_head1')
        traj_3d_head2 = data.get('traj_3d_head2')
        conf_mask_head1 = data.get('conf_mask_head1')
        conf_mask_head2 = data.get('conf_mask_head2')
        masks = data.get('masks')
        raw_video = data.get('raw_video')

    def process_video_frame(frame_idx):
        if raw_video is None:
            return np.zeros((video_height, video_width, 3), dtype=np.uint8)

        # Get the original frame
        raw_frame = raw_video[frame_idx]

        # Adjust value range to 0-255
        if raw_frame.max() <= 1.0:
            frame = (raw_frame * 255).astype(np.uint8)
        else:
            frame = raw_frame.astype(np.uint8)

        # Resize to fit the preview window
        h, w = frame.shape[:2]
        # Calculate size while maintaining aspect ratio
        if h / w > video_height / video_width:  # Height limited
            new_h = video_height
            new_w = int(w * (new_h / h))
        else:  # Width limited
            new_w = video_width
            new_h = int(h * (new_w / w))

        # Resize
        resized_frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # Create a black background
        display_frame = np.zeros((video_height, video_width, 3), dtype=np.uint8)

        # Place the resized frame in the center
        y_offset = (video_height - new_h) // 2
        x_offset = (video_width - new_w) // 2
        display_frame[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized_frame

        return display_frame

    server.scene.set_up_direction(up_dir)
    print("Setting up visualization!")

    # Add visualization controls
    with server.gui.add_folder("Visualization"):
        gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
        gui_show_head2 = server.gui.add_checkbox("Recon Points", True)
        gui_show_trajectories = server.gui.add_checkbox("Trajectories", True)
        gui_use_color_tint = server.gui.add_checkbox("Use Color Tint", True)

    # Process and center point clouds
    center_point = None
    if traj_3d_head1 is not None:
        xyz_head1 = traj_3d_head1[:, :, :3]
        rgb_head1 = traj_3d_head1[:, :, 3:6]
        if center_point is None:
            center_point = onp.mean(xyz_head1, axis=(0, 1), keepdims=True)
        xyz_head1 -= center_point
        if rgb_head1.sum(axis=(-1)).max() > 125:
            rgb_head1 /= 255.0

    if traj_3d_head2 is not None:
        xyz_head2 = traj_3d_head2[:, :, :3]
        rgb_head2 = traj_3d_head2[:, :, 3:6]
        if center_point is None:
            center_point = onp.mean(xyz_head2, axis=(0, 1), keepdims=True)
        xyz_head2 -= center_point
        if rgb_head2.sum(axis=(-1)).max() > 125:
            rgb_head2 /= 255.0

    # Determine number of frames
    F = max(
        traj_3d_head1.shape[0] if traj_3d_head1 is not None else 0,
        traj_3d_head2.shape[0] if traj_3d_head2 is not None else 0
    )
    num_frames = min(max_frames, F)
    traj_end_frame = min(traj_end_frame, num_frames)
    print(f"Number of frames: {num_frames}")
    xyz_head1 = xyz_head1[:num_frames]
    xyz_head2 = xyz_head2[:num_frames]
    rgb_head1 = rgb_head1[:num_frames]
    rgb_head2 = rgb_head2[:num_frames]

    # Add playback UI.
    with server.gui.add_folder("Playback"):
        gui_timestep = server.gui.add_slider(
            "Timestep",
            min=0,
            max=num_frames - 1,
            step=1,
            initial_value=0,
            disabled=True,
        )
        gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
        gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
        gui_playing = server.gui.add_checkbox("Playing", True)
        gui_framerate = server.gui.add_slider(
            "FPS", min=1, max=60, step=0.1, initial_value=20
        )
        gui_framerate_options = server.gui.add_button_group(
            "FPS options", ("10", "20", "30")
        )
        gui_show_all_frames = server.gui.add_checkbox("Show all frames", False)
        gui_stride = server.gui.add_slider(
            "Stride",
            min=1,
            max=num_frames,
            step=1,
            initial_value=5,
            disabled=True,  # Initially disabled
        )

    # Frame step buttons.
    @gui_next_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value + 1) % num_frames

    @gui_prev_frame.on_click
    def _(_) -> None:
        gui_timestep.value = (gui_timestep.value - 1) % num_frames

    # Disable frame controls when we're playing.
    @gui_playing.on_update
    def _(_) -> None:
        gui_timestep.disabled = gui_playing.value or gui_show_all_frames.value
        gui_next_frame.disabled = gui_playing.value or gui_show_all_frames.value
        gui_prev_frame.disabled = gui_playing.value or gui_show_all_frames.value

    # Set the framerate when we click one of the options.
    @gui_framerate_options.on_click
    def _(_) -> None:
        gui_framerate.value = int(gui_framerate_options.value)

    prev_timestep = gui_timestep.value

    # Toggle frame visibility when the timestep slider changes.
    @gui_timestep.on_update
    def _(_) -> None:
        nonlocal prev_timestep
        current_timestep = gui_timestep.value
        if not gui_show_all_frames.value:
            with server.atomic():
                if gui_show_head1.value:
                    frame_nodes_head1[current_timestep].visible = True
                    frame_nodes_head1[prev_timestep].visible = False
                if gui_show_head2.value:
                    frame_nodes_head2[current_timestep].visible = True
                    frame_nodes_head2[prev_timestep].visible = False
        prev_timestep = current_timestep
        server.flush()  # Optional!

    # Show or hide all frames based on the checkbox.
    @gui_show_all_frames.on_update
    def _(_) -> None:
        gui_stride.disabled = not gui_show_all_frames.value  # Enable/disable stride slider
        if gui_show_all_frames.value:
            # Show frames with stride
            stride = gui_stride.value
            with server.atomic():
                for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
                    node1.visible = gui_show_head1.value and (i % stride == 0)
                    node2.visible = gui_show_head2.value and (i % stride == 0)
            # Disable playback controls
            gui_playing.disabled = True
            gui_timestep.disabled = True
            gui_next_frame.disabled = True
            gui_prev_frame.disabled = True
        else:
            # Show only the current frame
            current_timestep = gui_timestep.value
            with server.atomic():
                for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
                    node1.visible = gui_show_head1.value and (i == current_timestep)
                    node2.visible = gui_show_head2.value and (i == current_timestep)
            # Re-enable playback controls
            gui_playing.disabled = False
            gui_timestep.disabled = gui_playing.value
            gui_next_frame.disabled = gui_playing.value
            gui_prev_frame.disabled = gui_playing.value

    # Update frame visibility when the stride changes.
    @gui_stride.on_update
    def _(_) -> None:
        if gui_show_all_frames.value:
            # Update frame visibility based on new stride
            stride = gui_stride.value
            with server.atomic():
                for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
                    node1.visible = gui_show_head1.value and (i % stride == 0)
                    node2.visible = gui_show_head2.value and (i % stride == 0)

    # Load in frames.
    server.scene.add_frame(
        "/frames",
        wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz,
        position=(0, 0, 0),
        show_axes=False,
    )
    frame_nodes_head1: list[viser.FrameHandle] = []
    frame_nodes_head2: list[viser.FrameHandle] = []

    # Extract RGB components for tinting
    blue_r, blue_g, blue_b = blue_rgb
    red_r, red_g, red_b = red_rgb

    # Create frames for each timestep
    frame_nodes_head1 = []
    frame_nodes_head2 = []
    for i in tqdm(range(num_frames)):
        # Process head1
        if traj_3d_head1 is not None:
            frame_nodes_head1.append(server.scene.add_frame(f"/frames/t{i}/head1", show_axes=False))
            position = xyz_head1[i]
            color = rgb_head1[i]
            if conf_mask_head1 is not None:
                position = position[conf_mask_head1[i]]
                color = color[conf_mask_head1[i]]

            # Add point cloud for head1 with optional blue tint
            color_head1 = color.copy()
            if gui_use_color_tint.value:
                color_head1 *= blend_ratio
                color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1)  # R
                color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1)  # G
                color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1)  # B

            server.scene.add_point_cloud(
                name=f"/frames/t{i}/head1/point_cloud",
                points=position[::downsample_factor],
                colors=color_head1[::downsample_factor],
                point_size=point_size,
                point_shape="rounded",
            )

        # Process head2
        if traj_3d_head2 is not None:
            frame_nodes_head2.append(server.scene.add_frame(f"/frames/t{i}/head2", show_axes=False))
            position = xyz_head2[i]
            color = rgb_head2[i]
            if conf_mask_head2 is not None:
                position = position[conf_mask_head2[i]]
                color = color[conf_mask_head2[i]]

            # Add point cloud for head2 with optional red tint
            color_head2 = color.copy()
            if gui_use_color_tint.value:
                color_head2 *= blend_ratio
                color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1)  # R
                color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1)  # G
                color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1)  # B

            server.scene.add_point_cloud(
                name=f"/frames/t{i}/head2/point_cloud",
                points=position[::downsample_factor],
                colors=color_head2[::downsample_factor],
                point_size=point_size,
                point_shape="rounded",
            )

    # Update visibility based on checkboxes
    @gui_show_head1.on_update
    def _(_) -> None:
        with server.atomic():
            for frame_node in frame_nodes_head1:
                frame_node.visible = gui_show_head1.value and (
                    gui_show_all_frames.value
                    or (not gui_show_all_frames.value)
                )

    @gui_show_head2.on_update
    def _(_) -> None:
        with server.atomic():
            for frame_node in frame_nodes_head2:
                frame_node.visible = gui_show_head2.value and (
                    gui_show_all_frames.value
                    or (not gui_show_all_frames.value)
                )

    # Initial visibility
    for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
        if gui_show_all_frames.value:
            node1.visible = gui_show_head1.value and (i % gui_stride.value == 0)
            node2.visible = gui_show_head2.value and (i % gui_stride.value == 0)
        else:
            node1.visible = gui_show_head1.value and (i == gui_timestep.value)
            node2.visible = gui_show_head2.value and (i == gui_timestep.value)

    # Process and visualize trajectories for head1
    if traj_3d_head1 is not None:
        # Get points over time
        xyz_head1_centered = xyz_head1.copy()

        # Select points to visualize
        num_points = xyz_head1.shape[1]
        points_to_visualize = min(num_points, num_traj_points)

        # Get the mask for the anchor frame and reshape it to match point cloud dimensions
        if mid_anchor:
            first_frame_mask = masks[num_frames // 2].reshape(-1)
        else:
            first_frame_mask = masks[0].reshape(-1)  # flattened (h*w,) mask

        # Calculate trajectory lengths for each point
        trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame]  # Shape: (num_frames, num_points, 3)
        traj_diffs = np.diff(trajectories, axis=0)  # Differences between consecutive frames
        traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0)  # Sum of distances for each point

        # Get points that are within the mask
        valid_indices = np.where(first_frame_mask)[0]

        if len(valid_indices) > 0:
            # Calculate average trajectory length for masked points
            masked_traj_lengths = traj_lengths[valid_indices]
            avg_traj_length = np.mean(masked_traj_lengths)

            if mask_folder is not None:
                # Do not filter points by trajectory length
                long_traj_indices = valid_indices
            else:
                # Filter points by trajectory length
                long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]

            # Randomly sample from the filtered points
            if len(long_traj_indices) > 0:
                # Random sampling without replacement
                selected_indices = np.random.choice(
                    len(long_traj_indices),
                    min(points_to_visualize, len(long_traj_indices)),
                    replace=False
                )
                # Get the actual indices in their original order
                valid_point_indices = long_traj_indices[np.sort(selected_indices)]
            else:
                valid_point_indices = np.array([])
        else:
            valid_point_indices = np.array([])

        if len(valid_point_indices) > 0:
            # Get trajectories for all valid points
            trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
            N_point = trajectories.shape[1]
            if color_code == "rainbow":
                point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
            elif color_code == "jet":
                point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
            # Handle frames with fewer than fixed_length_traj preceding frames
            for i in range(traj_end_frame - traj_start_frame):
                # Calculate the actual trajectory length for this frame
                actual_length = min(fixed_length_traj, i + 1)

                if actual_length > 1:  # Need at least 2 points to form a line
                    # Get the appropriate slice of trajectory data
                    start_idx = max(0, i - actual_length + 1)
                    end_idx = i + 1

                    # Create line segments between consecutive frames
                    traj_slice = trajectories[start_idx:end_idx]
                    line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
                    line_points = line_points.reshape(-1, 2, 3)

                    # Create corresponding colors
                    line_colors = np.tile(point_colors, (actual_length - 1, 1))
                    line_colors = np.stack([line_colors, line_colors], axis=1)

                    # Add line segments
                    server.scene.add_line_segments(
                        name=f"/frames/t{i + traj_start_frame}/head1/trajectory",
                        points=line_points,
                        colors=line_colors,
                        line_width=traj_line_width,
                        visible=gui_show_trajectories.value
                    )

    # Add trajectory controls functionality
    @gui_show_trajectories.on_update
    def _(_) -> None:
        with server.atomic():
            # Remove all existing trajectories
            for i in range(num_frames):
                try:
                    server.scene.remove_by_name(f"/frames/t{i}/head1/trajectory")
                except KeyError:
                    pass

            # Create new trajectories if enabled
            if gui_show_trajectories.value and traj_3d_head1 is not None:
                # Get the mask for the last frame and reshape it
                last_frame_mask = masks[traj_end_frame - 1].reshape(-1)

                # Calculate trajectory lengths
                trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame]
                traj_diffs = np.diff(trajectories, axis=0)
                traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0)

                # Get points that are within the mask
                valid_indices = np.where(last_frame_mask)[0]

                if len(valid_indices) > 0:
                    # Filter by trajectory length
                    masked_traj_lengths = traj_lengths[valid_indices]
                    avg_traj_length = np.mean(masked_traj_lengths)
                    long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]

                    # Randomly sample from the filtered points
                    if len(long_traj_indices) > 0:
                        # Random sampling without replacement
                        selected_indices = np.random.choice(
                            len(long_traj_indices),
                            min(points_to_visualize, len(long_traj_indices)),
                            replace=False
                        )
                        # Get the actual indices in their original order
                        valid_point_indices = long_traj_indices[np.sort(selected_indices)]
                    else:
                        valid_point_indices = np.array([])
                else:
                    valid_point_indices = np.array([])

                if len(valid_point_indices) > 0:
                    # Get trajectories for all valid points
                    trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
                    N_point = trajectories.shape[1]

                    if color_code == "rainbow":
                        point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
                    elif color_code == "jet":
                        point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]

                    # Handle frames with fewer than fixed_length_traj preceding frames
                    for i in range(traj_end_frame - traj_start_frame):
                        # Calculate the actual trajectory length for this frame
                        actual_length = min(fixed_length_traj, i + 1)

                        if actual_length > 1:  # Need at least 2 points to form a line
                            # Get the appropriate slice of trajectory data
                            start_idx = max(0, i - actual_length + 1)
                            end_idx = i + 1

                            # Create line segments between consecutive frames
                            traj_slice = trajectories[start_idx:end_idx]
                            line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
                            line_points = line_points.reshape(-1, 2, 3)

                            # Create corresponding colors
                            line_colors = np.tile(point_colors, (actual_length - 1, 1))
                            line_colors = np.stack([line_colors, line_colors], axis=1)

                            # Add line segments
                            server.scene.add_line_segments(
                                name=f"/frames/t{i + traj_start_frame}/head1/trajectory",
                                points=line_points,
                                colors=line_colors,
                                line_width=traj_line_width,
                                visible=True
                            )

    # Update color tinting when the checkbox changes
    @gui_use_color_tint.on_update
    def _(_) -> None:
        with server.atomic():
            for i in range(num_frames):
                # Update head1 point cloud
                if traj_3d_head1 is not None:
                    position = xyz_head1[i]
                    color = rgb_head1[i]
                    if conf_mask_head1 is not None:
                        position = position[conf_mask_head1[i]]
                        color = color[conf_mask_head1[i]]

                    color_head1 = color.copy()
                    if gui_use_color_tint.value:
                        color_head1 *= blend_ratio
                        color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1)  # R
                        color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1)  # G
                        color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1)  # B

                    server.scene.remove_by_name(f"/frames/t{i}/head1/point_cloud")
                    server.scene.add_point_cloud(
                        name=f"/frames/t{i}/head1/point_cloud",
                        points=position[::downsample_factor],
                        colors=color_head1[::downsample_factor],
                        point_size=point_size,
                        point_shape="rounded",
                    )

                # Update head2 point cloud
                if traj_3d_head2 is not None:
                    position = xyz_head2[i]
                    color = rgb_head2[i]
                    if conf_mask_head2 is not None:
                        position = position[conf_mask_head2[i]]
                        color = color[conf_mask_head2[i]]

                    color_head2 = color.copy()
                    if gui_use_color_tint.value:
                        color_head2 *= blend_ratio
                        color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1)  # R
                        color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1)  # G
                        color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1)  # B

                    server.scene.remove_by_name(f"/frames/t{i}/head2/point_cloud")
                    server.scene.add_point_cloud(
                        name=f"/frames/t{i}/head2/point_cloud",
                        points=position[::downsample_factor],
                        colors=color_head2[::downsample_factor],
                        point_size=point_size,
                        point_shape="rounded",
                    )

    # Initialize video preview
    if raw_video is not None:
        video_preview.image = process_video_frame(0)

    # Update video preview when timestep changes
    @gui_timestep.on_update
    def _(_) -> None:
        current_timestep = gui_timestep.value
        if raw_video is not None:
            video_preview.image = process_video_frame(current_timestep)

    # Playback update loop.
    log_memory_usage("before starting playback loop")

    prev_timestep = gui_timestep.value
    while True:
        current_timestep = gui_timestep.value

        # If timestep changes, update frame visibility
        if current_timestep != prev_timestep:
            with server.atomic():
                # Frame visibility is handled by the gui_timestep callbacks above.

                # Update video preview
                if raw_video is not None:
                    video_preview.image = process_video_frame(current_timestep)

        # Update in playback mode
        if gui_playing.value and not gui_show_all_frames.value:
            gui_timestep.value = (gui_timestep.value + 1) % num_frames

            # Update video preview in playback mode
            if raw_video is not None:
                video_preview.image = process_video_frame(gui_timestep.value)

        time.sleep(1.0 / gui_framerate.value)


if __name__ == "__main__":
    tyro.cli(visualize_st4rtrack)
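
The module is also runnable on its own: the `tyro.cli` entry point above exposes every keyword argument of `visualize_st4rtrack` as a command-line flag. Below is a minimal sketch of driving it from Python instead, assuming St4RTrack outputs (`pts3d*_p*.npy`, `conf*_p*.npy`) exist under `./results`; note that the call never returns, since the function ends in the playback loop.

```python
import viser

from vis_st4rtrack import load_trajectory_data, visualize_st4rtrack

# Load the trajectory files once, then hand the cache to the visualizer
# (the same flow app.py uses, just without Gradio or the proxy).
data = load_trajectory_data(traj_path="results", use_float16=True, max_frames=50)

server = viser.ViserServer()  # standalone viewer on a local port
visualize_st4rtrack(server=server, preloaded_data=data, max_frames=50)  # blocks in the playback loop
```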
viser_proxy_manager.py
ADDED
@@ -0,0 +1,223 @@
import asyncio

import httpx
import viser
import websockets
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response


class ViserProxyManager:
    """Manages Viser server instances for Gradio applications.

    This class handles the creation, retrieval, and cleanup of Viser server instances,
    as well as proxying HTTP and WebSocket requests to the appropriate Viser server.

    Args:
        app: The FastAPI application to which the proxy routes will be added.
        min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
            These ports are used only for internal communication and don't need to be publicly exposed.
        max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
            These ports are used only for internal communication and don't need to be publicly exposed.
        max_message_size: Maximum WebSocket message size in bytes. Defaults to 300MB.
    """

    def __init__(
        self,
        app: FastAPI,
        min_local_port: int = 8000,
        max_local_port: int = 9000,
        max_message_size: int = 300 * 1024 * 1024,  # 300MB default
    ) -> None:
        self._min_port = min_local_port
        self._max_port = max_local_port
        self._max_message_size = max_message_size
        self._server_from_session_hash: dict[str, viser.ViserServer] = {}
        self._last_port = self._min_port - 1  # Track last port tried

        @app.get("/viser/{server_id}/{proxy_path:path}")
        async def proxy(request: Request, server_id: str, proxy_path: str):
            """Proxy HTTP requests to the appropriate Viser server."""
            # Get the local port for this server ID
            server = self._server_from_session_hash.get(server_id)
            if server is None:
                return Response(content="Server not found", status_code=404)

            # Build target URL
            if proxy_path:
                path_suffix = f"/{proxy_path}"
            else:
                path_suffix = "/"

            target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
            if request.url.query:
                target_url += f"?{request.url.query}"

            # Forward request
            async with httpx.AsyncClient() as client:
                # Forward the original headers, but remove any problematic ones
                headers = dict(request.headers)
                headers.pop("host", None)  # Remove host header to avoid conflicts
                headers["accept-encoding"] = "identity"  # Disable compression

                proxied_req = client.build_request(
                    method=request.method,
                    url=target_url,
                    headers=headers,
                    content=await request.body(),
                )
                proxied_resp = await client.send(proxied_req, stream=True)

                # Get response headers
                response_headers = dict(proxied_resp.headers)

                # Read the response body and return it to the client
                content = await proxied_resp.aread()
                return Response(
                    content=content,
                    status_code=proxied_resp.status_code,
                    headers=response_headers,
                )

        # WebSocket Proxy
        @app.websocket("/viser/{server_id}")
        async def websocket_proxy(websocket: WebSocket, server_id: str):
            """Proxy WebSocket connections to the appropriate Viser server."""
            try:
                await websocket.accept()

                server = self._server_from_session_hash.get(server_id)
                if server is None:
                    await websocket.close(code=1008, reason="Not Found")
                    return

                # Determine target WebSocket URL
                target_ws_url = f"ws://127.0.0.1:{server.get_port()}"

                if not target_ws_url:
                    await websocket.close(code=1008, reason="Not Found")
                    return

                try:
                    # Connect to the target WebSocket with increased message size and timeout
                    async with websockets.connect(
                        target_ws_url,
                        max_size=self._max_message_size,
                        ping_interval=30,  # Send ping every 30 seconds
                        ping_timeout=10,  # Wait 10 seconds for pong response
                        close_timeout=5,  # Wait 5 seconds for close handshake
                    ) as ws_target:
                        # Create tasks for bidirectional communication
                        async def forward_to_target():
                            """Forward messages from the client to the target WebSocket."""
                            try:
                                while True:
                                    data = await websocket.receive_bytes()
                                    await ws_target.send(data, text=False)
                            except WebSocketDisconnect:
                                try:
                                    await ws_target.close()
                                except RuntimeError:
                                    pass

                        async def forward_from_target():
                            """Forward messages from the target WebSocket to the client."""
                            try:
                                while True:
                                    data = await ws_target.recv(decode=False)
                                    await websocket.send_bytes(data)
                            except websockets.exceptions.ConnectionClosed:
                                try:
                                    await websocket.close()
                                except RuntimeError:
                                    pass

                        # Run both forwarding tasks concurrently
                        forward_task = asyncio.create_task(forward_to_target())
                        backward_task = asyncio.create_task(forward_from_target())

                        # Wait for either task to complete (which means a connection was closed)
                        done, pending = await asyncio.wait(
                            [forward_task, backward_task],
                            return_when=asyncio.FIRST_COMPLETED,
                        )

                        # Cancel the remaining task
                        for task in pending:
                            task.cancel()

                except websockets.exceptions.ConnectionClosedError as e:
                    print(f"WebSocket connection closed with error: {e}")
                    await websocket.close(code=1011, reason="Connection to target closed")

            except Exception as e:
                print(f"WebSocket proxy error: {e}")
                try:
                    await websocket.close(code=1011, reason=str(e)[:120])  # Limit reason length
                except:
                    pass  # Already closed

    def start_server(self, server_id: str) -> viser.ViserServer:
        """Start a new Viser server and associate it with the given server ID.

        Finds an available port within the configured min_local_port and max_local_port range.
        These ports are used only for internal communication and don't need to be publicly exposed.

        Args:
            server_id: The unique identifier to associate with the new server.

        Returns:
            The newly created Viser server instance.

        Raises:
            RuntimeError: If no free ports are available in the configured range.
        """
        import socket

        # Start searching from the last port + 1 (with wraparound)
        port_range_size = self._max_port - self._min_port + 1
        start_port = (
            (self._last_port + 1 - self._min_port) % port_range_size
        ) + self._min_port

        # Try each port once
        for offset in range(port_range_size):
            port = (
                (start_port - self._min_port + offset) % port_range_size
            ) + self._min_port
            try:
                # Check if port is available by attempting to bind to it
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("127.0.0.1", port))
                # Port is available, create server with this port
                server = viser.ViserServer(port=port)
                self._server_from_session_hash[server_id] = server
                self._last_port = port
                return server
            except OSError:
                # Port is in use, try the next one
                continue

        # If we get here, no ports were available
        raise RuntimeError(
            f"No available local ports in range {self._min_port}-{self._max_port}"
        )

    def get_server(self, server_id: str) -> viser.ViserServer:
        """Retrieve a Viser server instance by its ID.

        Args:
            server_id: The unique identifier of the server to retrieve.

        Returns:
            The Viser server instance associated with the given ID.
        """
        return self._server_from_session_hash[server_id]

    def stop_server(self, server_id: str) -> None:
        """Stop a Viser server and remove it from the manager.

        Args:
            server_id: The unique identifier of the server to stop.
        """
        self._server_from_session_hash.pop(server_id).stop()