Junyi42 committed
Commit a9472e6 · 1 Parent(s): 0f252ee
Files changed (6)
  1. Dockerfile +22 -0
  2. README.md +30 -8
  3. app.py +105 -0
  4. requirements.txt +15 -0
  5. vis_st4rtrack.py +781 -0
  6. viser_proxy_manager.py +223 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ FROM python:3.12-slim
+
+ WORKDIR /app
+
+ # Install system dependencies for OpenCV and build tools
+ RUN apt-get update && apt-get install -y \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     build-essential \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ # Make port 7860 available to the world outside the container
+ EXPOSE 7860
+
+ # Command to run when the container starts
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,12 +1,34 @@
  ---
- title: Viser Bonn
- emoji: 🏆
- colorFrom: indigo
- colorTo: gray
- sdk: gradio
- sdk_version: 5.25.2
- app_file: app.py
+ title: Viser Gradio Embed
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: pink
+ sdk: docker
+ app_port: 7860
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Viser + Gradio
+
+ Demo for integrating [viser](https://github.com/nerfstudio-project/viser) 3D
+ visualizations into a [Gradio](https://www.gradio.app/) application.
+
+ - Uses Gradio's session management to create isolated 3D visualization contexts.
+ - Exposes both Gradio and Viser over the same port.
+
+ ## Deploying on HuggingFace Spaces
+
+ **[ [Live example](https://huggingface.co/spaces/brentyi/viser-gradio-embed) ]**
+
+ This repository should work out-of-the-box with HF Spaces via Docker.
+
+ - Unlike a vanilla Gradio Space, this is unfortunately not supported by [ZeroGPU](https://huggingface.co/docs/hub/en/spaces-zerogpu).
+
+ ## Local Demo
+
+ ```bash
+ pip install -r requirements.txt
+ python app.py
+ ```
+
+ https://github.com/user-attachments/assets/b94a117a-b9e5-4854-805a-8666941c7816
app.py ADDED
@@ -0,0 +1,105 @@
+ import random
+ import threading
+ import psutil
+ import fastapi
+ import gradio as gr
+ import uvicorn
+
+ from viser_proxy_manager import ViserProxyManager
+ from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage
+
+ # Global cache for loaded data
+ global_data_cache = None
+
+ def check_ram_usage(threshold_percent=90):
+     """Check if RAM usage is above the threshold.
+
+     Args:
+         threshold_percent: Maximum RAM usage percentage allowed
+
+     Returns:
+         bool: True if RAM usage is below threshold, False otherwise
+     """
+     ram_percent = psutil.virtual_memory().percent
+     print(f"Current RAM usage: {ram_percent}%")
+     return ram_percent < threshold_percent
+
+
+ def main() -> None:
+     # Load data once at startup using the function from vis_st4rtrack.py
+     global global_data_cache
+     global_data_cache = load_trajectory_data(use_float16=True, max_frames=120, traj_path="bonn_results", mask_folder="./train")
+
+     app = fastapi.FastAPI()
+     viser_manager = ViserProxyManager(app)
+
+     # Create a Gradio interface with title, iframe, and buttons
+     with gr.Blocks(title="Viser Viewer") as demo:
+         # Add the iframe with a border
+         iframe_html = gr.HTML("")
+         status_text = gr.Markdown("")  # Add status text component
+
+         @demo.load(outputs=[iframe_html, status_text])
+         def start_server(request: gr.Request):
+             assert request.session_hash is not None
+
+             # Check RAM usage before starting visualization
+             if not check_ram_usage(threshold_percent=100):
+                 return """
+                 <div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
+                     <h2>⚠️ Server is currently under high load</h2>
+                     <p>Please try again later when resources are available.</p>
+                 </div>
+                 """, "**System Status:** High memory usage detected. Visualization not loaded to prevent server overload."
+
+             viser_manager.start_server(request.session_hash)
+
+             # Use the request's base URL if available
+             host = request.headers["host"]
+
+             # Determine protocol (use HTTPS for HuggingFace Spaces or other secure environments)
+             protocol = (
+                 "https"
+                 if request.headers.get("x-forwarded-proto") == "https"
+                 else "http"
+             )
+
+             # Add visualization in a separate thread
+             server = viser_manager.get_server(request.session_hash)
+             threading.Thread(
+                 target=visualize_st4rtrack,
+                 kwargs={
+                     "server": server,
+                     "use_float16": True,
+                     "preloaded_data": global_data_cache,  # Pass the preloaded data
+                     "color_code": "jet",
+                     "blue_rgb": (0.0, 0.149, 0.463),  # #002676
+                     "red_rgb": (0.769, 0.510, 0.055),  # #FDB515
+                     "blend_ratio": 0.7
+                 },
+                 daemon=True
+             ).start()
+
+             return f"""
+             <iframe
+                 src="{protocol}://{host}/viser/{request.session_hash}/"
+                 width="100%"
+                 height="500px"
+                 frameborder="0"
+                 style="display: block;"
+                 allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
+                 loading="lazy"
+             ></iframe>
+             """, "**System Status:** Visualization loaded successfully."
+
+         @demo.unload
+         def stop(request: gr.Request):
+             assert request.session_hash is not None
+             viser_manager.stop_server(request.session_hash)
+
+     gr.mount_gradio_app(app, demo, "/")
+     uvicorn.run(app, host="0.0.0.0", port=7860)
+
+
+ if __name__ == "__main__":
+     main()
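
For orientation, the per-session URL structure produced by `start_server` above is what the proxy routes in `viser_proxy_manager.py` (further down in this commit) resolve; roughly:

```
iframe src:              {protocol}://{host}/viser/{session_hash}/
HTTP proxy route:        GET /viser/{server_id}/{proxy_path:path}  ->  http://127.0.0.1:<local viser port>/<proxy_path>
WebSocket proxy route:   /viser/{server_id}                        ->  ws://127.0.0.1:<local viser port>
```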
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ # git+https://github.com/nerfstudio-project/viser.git
+ viser>=0.2.23
+ gradio==5.23.1
+ fastapi==0.115.11
+ uvicorn==0.34.0
+ httpx==0.27.2
+ websockets==15.0.1
+ tyro==0.4.1
+ numpy>=1.20.0
+ tqdm>=4.62.0
+ opencv-python>=4.5.0
+ imageio>=2.25.0
+ matplotlib>=3.5.0
+ pyliblzfse>=0.1.0
+ psutil>=5.9.0
vis_st4rtrack.py ADDED
@@ -0,0 +1,781 @@
+ """Record3D visualizer
+
+ Parse and stream record3d captures. To get the demo data, see `./assets/download_record3d_dance.sh`.
+ """
+
+ import time
+ from pathlib import Path
+
+ import numpy as onp
+ import tyro
+ import cv2
+ from tqdm.auto import tqdm
+
+ import viser
+ import viser.extras
+ import viser.transforms as tf
+
+ from glob import glob
+ import numpy as np
+ import imageio.v3 as iio
+ import matplotlib.pyplot as plt
+ import psutil
+
+ def log_memory_usage(message=""):
+     """Log current memory usage with an optional message."""
+     process = psutil.Process()
+     memory_info = process.memory_info()
+     memory_mb = memory_info.rss / (1024 * 1024)  # Convert to MB
+     print(f"Memory usage {message}: {memory_mb:.2f} MB")
+
+ def load_trajectory_data(traj_path="results", use_float16=True, max_frames=None, mask_folder='./train'):
+     """Load trajectory data from files.
+
+     Args:
+         traj_path: Path to the directory containing trajectory data
+         use_float16: Whether to convert data to float16 to save memory
+         max_frames: Maximum number of frames to load (None for all)
+         mask_folder: Path to the directory containing mask images
+
+     Returns:
+         A dictionary containing loaded data
+     """
+     log_memory_usage("before loading data")
+
+     data_cache = {
+         'traj_3d_head1': None,
+         'traj_3d_head2': None,
+         'conf_mask_head1': None,
+         'conf_mask_head2': None,
+         'masks': None,
+         'raw_video': None,
+         'loaded': False
+     }
+
+     # Load masks
+     masks_paths = sorted(glob(mask_folder + '/*.jpg'))
+     masks = None
+
+     if masks_paths:
+         masks = [iio.imread(p) for p in masks_paths]
+         masks = np.stack(masks, axis=0)
+         # Convert masks to binary (0 or 1)
+         masks = (masks < 1).astype(np.float32)
+         masks = masks.sum(axis=-1) > 2  # Combine all channels, True where any channel was 1
+         print(f"Original masks shape: {masks.shape}")
+     else:
+         print("No masks found. Will create default masks when needed.")
+
+     data_cache['masks'] = masks
+
+     if Path(traj_path).is_dir():
+         # Find all trajectory files
+         traj_3d_paths_head1 = sorted(glob(traj_path + '/pts3d1_p*.npy'),
+                                      key=lambda x: int(x.split('_p')[-1].split('.')[0]))
+         conf_paths_head1 = sorted(glob(traj_path + '/conf1_p*.npy'),
+                                   key=lambda x: int(x.split('_p')[-1].split('.')[0]))
+
+         traj_3d_paths_head2 = sorted(glob(traj_path + '/pts3d2_p*.npy'),
+                                      key=lambda x: int(x.split('_p')[-1].split('.')[0]))
+         conf_paths_head2 = sorted(glob(traj_path + '/conf2_p*.npy'),
+                                   key=lambda x: int(x.split('_p')[-1].split('.')[0]))
+
+         # Limit number of frames if specified
+         if max_frames is not None:
+             traj_3d_paths_head1 = traj_3d_paths_head1[:max_frames]
+             conf_paths_head1 = conf_paths_head1[:max_frames] if conf_paths_head1 else []
+             traj_3d_paths_head2 = traj_3d_paths_head2[:max_frames]
+             conf_paths_head2 = conf_paths_head2[:max_frames] if conf_paths_head2 else []
+
+         # Process head1
+         if traj_3d_paths_head1:
+             if use_float16:
+                 traj_3d_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head1], axis=0)
+             else:
+                 traj_3d_head1 = onp.stack([onp.load(p) for p in traj_3d_paths_head1], axis=0)
+
+             log_memory_usage("after loading head1 data")
+
+             h, w, _ = traj_3d_head1.shape[1:]
+             num_frames = traj_3d_head1.shape[0]
+
+             # If masks is None, create default masks (all ones)
+             if masks is None:
+                 masks = np.ones((num_frames, h, w), dtype=bool)
+                 print(f"Created default masks with shape: {masks.shape}")
+                 data_cache['masks'] = masks
+             else:
+                 # Resize masks to match trajectory dimensions using nearest neighbor interpolation
+                 masks_resized = np.zeros((masks.shape[0], h, w), dtype=bool)
+                 for i in range(masks.shape[0]):
+                     masks_resized[i] = cv2.resize(
+                         masks[i].astype(np.uint8),
+                         (w, h),
+                         interpolation=cv2.INTER_NEAREST
+                     ).astype(bool)
+
+                 print(f"Resized masks shape: {masks_resized.shape}")
+                 data_cache['masks'] = masks_resized
+
+             # Reshape trajectory data
+             traj_3d_head1 = traj_3d_head1.reshape(traj_3d_head1.shape[0], -1, 6)
+             data_cache['traj_3d_head1'] = traj_3d_head1
+
+             if conf_paths_head1:
+                 conf_head1 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head1], axis=0)
+                 conf_head1 = conf_head1.reshape(conf_head1.shape[0], -1)
+                 conf_head1 = conf_head1.mean(axis=0)
+                 # repeat the conf_head1 to match the number of frames in the dimension 0
+                 conf_head1 = np.tile(conf_head1, (num_frames, 1))
+                 # Convert to float32 before calculating percentile to avoid overflow
+                 conf_thre = np.percentile(conf_head1.astype(np.float32), 1)  # Default percentile
+                 conf_mask_head1 = conf_head1 > conf_thre
+                 data_cache['conf_mask_head1'] = conf_mask_head1
+
+         # Process head2
+         if traj_3d_paths_head2:
+             if use_float16:
+                 traj_3d_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in traj_3d_paths_head2], axis=0)
+             else:
+                 traj_3d_head2 = onp.stack([onp.load(p) for p in traj_3d_paths_head2], axis=0)
+
+             log_memory_usage("after loading head2 data")
+
+             # Store raw video data
+             raw_video = traj_3d_head2[:, :, :, 3:6]  # [num_frames, h, w, 3]
+             data_cache['raw_video'] = raw_video
+
+             traj_3d_head2 = traj_3d_head2.reshape(traj_3d_head2.shape[0], -1, 6)
+             data_cache['traj_3d_head2'] = traj_3d_head2
+
+             if conf_paths_head2:
+                 conf_head2 = onp.stack([onp.load(p).astype(onp.float16) for p in conf_paths_head2], axis=0)
+                 conf_head2 = conf_head2.reshape(conf_head2.shape[0], -1)
+                 # set conf thre to be 1 percentile of the conf_head2, for each frame
+                 conf_thre = np.percentile(conf_head2.astype(np.float32), 1, axis=1)
+                 conf_mask_head2 = conf_head2 > conf_thre[:, None]
+                 data_cache['conf_mask_head2'] = conf_mask_head2
+
+     data_cache['loaded'] = True
+     log_memory_usage("after loading all data")
+     return data_cache
+
+ def visualize_st4rtrack(
+     traj_path: str = "results",
+     up_dir: str = "-z",  # should be +z or -z
+     max_frames: int = 100,
+     share: bool = False,
+     point_size: float = 0.005,
+     downsample_factor: int = 3,
+     num_traj_points: int = 100,
+     conf_thre_percentile: float = 1,
+     traj_end_frame: int = 100,
+     traj_start_frame: int = 0,
+     traj_line_width: float = 3.,
+     fixed_length_traj: int = 20,
+     server: viser.ViserServer = None,
+     use_float16: bool = True,
+     preloaded_data: dict = None,  # Add this parameter to accept preloaded data
+     color_code: str = "jet",
+     # Updated hex colors: #002676 for blue and #FDB515 for red/gold
+     blue_rgb: tuple[float, float, float] = (0.0, 0.149, 0.463),  # #002676
+     red_rgb: tuple[float, float, float] = (0.769, 0.510, 0.055),  # #FDB515
+     blend_ratio: float = 0.7,
+     mask_folder: str = None,
+     mid_anchor: bool = False,
+     video_width: int = 320,  # Video display width
+     video_height: int = 180,  # Video display height
+     camera_position: tuple[float, float, float] = (1e-3, 1.5, -0.2),
+ ) -> None:
+     log_memory_usage("at start of visualization")
+
+     if server is None:
+         server = viser.ViserServer()
+     if share:
+         server.request_share_url()
+
+     @server.on_client_connect
+     def _(client: viser.ClientHandle) -> None:
+         client.camera.position = camera_position
+         client.camera.look_at = (0, 0, 0)
+
+     # Configure the GUI panel size and layout
+     server.gui.configure_theme(
+         control_layout="collapsible",
+         control_width="small",
+         dark_mode=False,
+         show_logo=False,
+         show_share_button=True
+     )
+
+     # Add video preview to the GUI panel - placed at the top
+     video_preview = server.gui.add_image(
+         np.zeros((video_height, video_width, 3), dtype=np.uint8),  # Initial blank image
+         format="jpeg"
+     )
+
+     # Use preloaded data if available
+     if preloaded_data and preloaded_data.get('loaded', False):
+         traj_3d_head1 = preloaded_data.get('traj_3d_head1')
+         traj_3d_head2 = preloaded_data.get('traj_3d_head2')
+         conf_mask_head1 = preloaded_data.get('conf_mask_head1')
+         conf_mask_head2 = preloaded_data.get('conf_mask_head2')
+         masks = preloaded_data.get('masks')
+         raw_video = preloaded_data.get('raw_video')
+         print("Using preloaded data!")
+     else:
+         # Load data using the shared function
+         print("No preloaded data available, loading from files...")
+         data = load_trajectory_data(traj_path, use_float16, max_frames, mask_folder)
+         traj_3d_head1 = data.get('traj_3d_head1')
+         traj_3d_head2 = data.get('traj_3d_head2')
+         conf_mask_head1 = data.get('conf_mask_head1')
+         conf_mask_head2 = data.get('conf_mask_head2')
+         masks = data.get('masks')
+         raw_video = data.get('raw_video')
+
+     def process_video_frame(frame_idx):
+         if raw_video is None:
+             return np.zeros((video_height, video_width, 3), dtype=np.uint8)
+
+         # Get the original frame
+         raw_frame = raw_video[frame_idx]
+
+         # Adjust value range to 0-255
+         if raw_frame.max() <= 1.0:
+             frame = (raw_frame * 255).astype(np.uint8)
+         else:
+             frame = raw_frame.astype(np.uint8)
+
+         # Resize to fit the preview window
+         h, w = frame.shape[:2]
+         # Calculate size while maintaining aspect ratio
+         if h/w > video_height/video_width:  # Height limited
+             new_h = video_height
+             new_w = int(w * (new_h / h))
+         else:  # Width limited
+             new_w = video_width
+             new_h = int(h * (new_w / w))
+
+         # Resize
+         resized_frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
+
+         # Create a black background
+         display_frame = np.zeros((video_height, video_width, 3), dtype=np.uint8)
+
+         # Place the resized frame in the center
+         y_offset = (video_height - new_h) // 2
+         x_offset = (video_width - new_w) // 2
+         display_frame[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized_frame
+
+         return display_frame
+
+     server.scene.set_up_direction(up_dir)
+     print("Setting up visualization!")
+
+     # Add visualization controls
+     with server.gui.add_folder("Visualization"):
+         gui_show_head1 = server.gui.add_checkbox("Tracking Points", True)
+         gui_show_head2 = server.gui.add_checkbox("Recon Points", True)
+         gui_show_trajectories = server.gui.add_checkbox("Trajectories", True)
+         gui_use_color_tint = server.gui.add_checkbox("Use Color Tint", True)
+
+     # Process and center point clouds
+     center_point = None
+     if traj_3d_head1 is not None:
+         xyz_head1 = traj_3d_head1[:, :, :3]
+         rgb_head1 = traj_3d_head1[:, :, 3:6]
+         if center_point is None:
+             center_point = onp.mean(xyz_head1, axis=(0, 1), keepdims=True)
+         xyz_head1 -= center_point
+         if rgb_head1.sum(axis=(-1)).max() > 125:
+             rgb_head1 /= 255.0
+
+     if traj_3d_head2 is not None:
+         xyz_head2 = traj_3d_head2[:, :, :3]
+         rgb_head2 = traj_3d_head2[:, :, 3:6]
+         if center_point is None:
+             center_point = onp.mean(xyz_head2, axis=(0, 1), keepdims=True)
+         xyz_head2 -= center_point
+         if rgb_head2.sum(axis=(-1)).max() > 125:
+             rgb_head2 /= 255.0
+
+     # Determine number of frames
+     F = max(
+         traj_3d_head1.shape[0] if traj_3d_head1 is not None else 0,
+         traj_3d_head2.shape[0] if traj_3d_head2 is not None else 0
+     )
+     num_frames = min(max_frames, F)
+     traj_end_frame = min(traj_end_frame, num_frames)
+     print(f"Number of frames: {num_frames}")
+     xyz_head1 = xyz_head1[:num_frames]
+     xyz_head2 = xyz_head2[:num_frames]
+     rgb_head1 = rgb_head1[:num_frames]
+     rgb_head2 = rgb_head2[:num_frames]
+
+     # Add playback UI.
+     with server.gui.add_folder("Playback"):
+         gui_timestep = server.gui.add_slider(
+             "Timestep",
+             min=0,
+             max=num_frames - 1,
+             step=1,
+             initial_value=0,
+             disabled=True,
+         )
+         gui_next_frame = server.gui.add_button("Next Frame", disabled=True)
+         gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True)
+         gui_playing = server.gui.add_checkbox("Playing", True)
+         gui_framerate = server.gui.add_slider(
+             "FPS", min=1, max=60, step=0.1, initial_value=20
+         )
+         gui_framerate_options = server.gui.add_button_group(
+             "FPS options", ("10", "20", "30")
+         )
+         gui_show_all_frames = server.gui.add_checkbox("Show all frames", False)
+         gui_stride = server.gui.add_slider(
+             "Stride",
+             min=1,
+             max=num_frames,
+             step=1,
+             initial_value=5,
+             disabled=True,  # Initially disabled
+         )
+
+     # Frame step buttons.
+     @gui_next_frame.on_click
+     def _(_) -> None:
+         gui_timestep.value = (gui_timestep.value + 1) % num_frames
+
+     @gui_prev_frame.on_click
+     def _(_) -> None:
+         gui_timestep.value = (gui_timestep.value - 1) % num_frames
+
+     # Disable frame controls when we're playing.
+     @gui_playing.on_update
+     def _(_) -> None:
+         gui_timestep.disabled = gui_playing.value or gui_show_all_frames.value
+         gui_next_frame.disabled = gui_playing.value or gui_show_all_frames.value
+         gui_prev_frame.disabled = gui_playing.value or gui_show_all_frames.value
+
+     # Set the framerate when we click one of the options.
+     @gui_framerate_options.on_click
+     def _(_) -> None:
+         gui_framerate.value = int(gui_framerate_options.value)
+
+     prev_timestep = gui_timestep.value
+
+     # Toggle frame visibility when the timestep slider changes.
+     @gui_timestep.on_update
+     def _(_) -> None:
+         nonlocal prev_timestep
+         current_timestep = gui_timestep.value
+         if not gui_show_all_frames.value:
+             with server.atomic():
+                 if gui_show_head1.value:
+                     frame_nodes_head1[current_timestep].visible = True
+                     frame_nodes_head1[prev_timestep].visible = False
+                 if gui_show_head2.value:
+                     frame_nodes_head2[current_timestep].visible = True
+                     frame_nodes_head2[prev_timestep].visible = False
+         prev_timestep = current_timestep
+         server.flush()  # Optional!
+
+     # Show or hide all frames based on the checkbox.
+     @gui_show_all_frames.on_update
+     def _(_) -> None:
+         gui_stride.disabled = not gui_show_all_frames.value  # Enable/disable stride slider
+         if gui_show_all_frames.value:
+             # Show frames with stride
+             stride = gui_stride.value
+             with server.atomic():
+                 for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
+                     node1.visible = gui_show_head1.value and (i % stride == 0)
+                     node2.visible = gui_show_head2.value and (i % stride == 0)
+             # Disable playback controls
+             gui_playing.disabled = True
+             gui_timestep.disabled = True
+             gui_next_frame.disabled = True
+             gui_prev_frame.disabled = True
+         else:
+             # Show only the current frame
+             current_timestep = gui_timestep.value
+             with server.atomic():
+                 for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
+                     node1.visible = gui_show_head1.value and (i == current_timestep)
+                     node2.visible = gui_show_head2.value and (i == current_timestep)
+             # Re-enable playback controls
+             gui_playing.disabled = False
+             gui_timestep.disabled = gui_playing.value
+             gui_next_frame.disabled = gui_playing.value
+             gui_prev_frame.disabled = gui_playing.value
+
+     # Update frame visibility when the stride changes.
+     @gui_stride.on_update
+     def _(_) -> None:
+         if gui_show_all_frames.value:
+             # Update frame visibility based on new stride
+             stride = gui_stride.value
+             with server.atomic():
+                 for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
+                     node1.visible = gui_show_head1.value and (i % stride == 0)
+                     node2.visible = gui_show_head2.value and (i % stride == 0)
+
+     # Load in frames.
+     server.scene.add_frame(
+         "/frames",
+         wxyz=tf.SO3.exp(onp.array([onp.pi / 2.0, 0.0, 0.0])).wxyz,
+         position=(0, 0, 0),
+         show_axes=False,
+     )
+     frame_nodes_head1: list[viser.FrameHandle] = []
+     frame_nodes_head2: list[viser.FrameHandle] = []
+
+     # Extract RGB components for tinting
+     blue_r, blue_g, blue_b = blue_rgb
+     red_r, red_g, red_b = red_rgb
+
+     # Create frames for each timestep
+     frame_nodes_head1 = []
+     frame_nodes_head2 = []
+     for i in tqdm(range(num_frames)):
+         # Process head1
+         if traj_3d_head1 is not None:
+             frame_nodes_head1.append(server.scene.add_frame(f"/frames/t{i}/head1", show_axes=False))
+             position = xyz_head1[i]
+             color = rgb_head1[i]
+             if conf_mask_head1 is not None:
+                 position = position[conf_mask_head1[i]]
+                 color = color[conf_mask_head1[i]]
+
+             # Add point cloud for head1 with optional blue tint
+             color_head1 = color.copy()
+             if gui_use_color_tint.value:
+                 color_head1 *= blend_ratio
+                 color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1)  # R
+                 color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1)  # G
+                 color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1)  # B
+
+             server.scene.add_point_cloud(
+                 name=f"/frames/t{i}/head1/point_cloud",
+                 points=position[::downsample_factor],
+                 colors=color_head1[::downsample_factor],
+                 point_size=point_size,
+                 point_shape="rounded",
+             )
+
+         # Process head2
+         if traj_3d_head2 is not None:
+             frame_nodes_head2.append(server.scene.add_frame(f"/frames/t{i}/head2", show_axes=False))
+             position = xyz_head2[i]
+             color = rgb_head2[i]
+             if conf_mask_head2 is not None:
+                 position = position[conf_mask_head2[i]]
+                 color = color[conf_mask_head2[i]]
+
+             # Add point cloud for head2 with optional red tint
+             color_head2 = color.copy()
+             if gui_use_color_tint.value:
+                 color_head2 *= blend_ratio
+                 color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1)  # R
+                 color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1)  # G
+                 color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1)  # B
+
+             server.scene.add_point_cloud(
+                 name=f"/frames/t{i}/head2/point_cloud",
+                 points=position[::downsample_factor],
+                 colors=color_head2[::downsample_factor],
+                 point_size=point_size,
+                 point_shape="rounded",
+             )
+
+     # Update visibility based on checkboxes
+     @gui_show_head1.on_update
+     def _(_) -> None:
+         with server.atomic():
+             for frame_node in frame_nodes_head1:
+                 frame_node.visible = gui_show_head1.value and (
+                     gui_show_all_frames.value
+                     or (not gui_show_all_frames.value)
+                 )
+
+     @gui_show_head2.on_update
+     def _(_) -> None:
+         with server.atomic():
+             for frame_node in frame_nodes_head2:
+                 frame_node.visible = gui_show_head2.value and (
+                     gui_show_all_frames.value
+                     or (not gui_show_all_frames.value)
+                 )
+
+     # Initial visibility
+     for i, (node1, node2) in enumerate(zip(frame_nodes_head1, frame_nodes_head2)):
+         if gui_show_all_frames.value:
+             node1.visible = gui_show_head1.value and (i % gui_stride.value == 0)
+             node2.visible = gui_show_head2.value and (i % gui_stride.value == 0)
+         else:
+             node1.visible = gui_show_head1.value and (i == gui_timestep.value)
+             node2.visible = gui_show_head2.value and (i == gui_timestep.value)
+
+     # Process and visualize trajectories for head1
+     if traj_3d_head1 is not None:
+         # Get points over time
+         xyz_head1_centered = xyz_head1.copy()
+
+         # Select points to visualize
+         num_points = xyz_head1.shape[1]
+         points_to_visualize = min(num_points, num_traj_points)
+
+         # Get the mask for the first frame and reshape it to match point cloud dimensions
+         if mid_anchor:
+             first_frame_mask = masks[num_frames//2].reshape(-1)
+         else:
+             first_frame_mask = masks[0].reshape(-1)  # [#points, h]
+
+         # Calculate trajectory lengths for each point
+         trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame]  # Shape: (num_frames, num_points, 3)
+         traj_diffs = np.diff(trajectories, axis=0)  # Differences between consecutive frames
+         traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0)  # Sum of distances for each point
+
+         # Get points that are within the mask
+         valid_indices = np.where(first_frame_mask)[0]
+
+         if len(valid_indices) > 0:
+             # Calculate average trajectory length for masked points
+             masked_traj_lengths = traj_lengths[valid_indices]
+             avg_traj_length = np.mean(masked_traj_lengths)
+
+             if mask_folder is not None:
+                 # do not filter points by trajectory length
+                 long_traj_indices = valid_indices
+             else:
+                 # Filter points by trajectory length
+                 long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
+
+             # Randomly sample from the filtered points
+             if len(long_traj_indices) > 0:
+                 # Random sampling without replacement
+                 selected_indices = np.random.choice(
+                     len(long_traj_indices),
+                     min(points_to_visualize, len(long_traj_indices)),
+                     replace=False
+                 )
+                 # Get the actual indices in their original order
+                 valid_point_indices = long_traj_indices[np.sort(selected_indices)]
+             else:
+                 valid_point_indices = np.array([])
+         else:
+             valid_point_indices = np.array([])
+
+         if len(valid_point_indices) > 0:
+             # Get trajectories for all valid points
+             trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
+             N_point = trajectories.shape[1]
+             if color_code == "rainbow":
+                 point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
+             elif color_code == "jet":
+                 point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
+             # Modify the loop to handle frames less than fixed_length_traj
+             for i in range(traj_end_frame - traj_start_frame):
+                 # Calculate the actual trajectory length for this frame
+                 actual_length = min(fixed_length_traj, i + 1)
+
+                 if actual_length > 1:  # Need at least 2 points to form a line
+                     # Get the appropriate slice of trajectory data
+                     start_idx = max(0, i - actual_length + 1)
+                     end_idx = i + 1
+
+                     # Create line segments between consecutive frames
+                     traj_slice = trajectories[start_idx:end_idx]
+                     line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
+                     line_points = line_points.reshape(-1, 2, 3)
+
+                     # Create corresponding colors
+                     line_colors = np.tile(point_colors, (actual_length-1, 1))
+                     line_colors = np.stack([line_colors, line_colors], axis=1)
+
+                     # Add line segments
+                     server.scene.add_line_segments(
+                         name=f"/frames/t{i+traj_start_frame}/head1/trajectory",
+                         points=line_points,
+                         colors=line_colors,
+                         line_width=traj_line_width,
+                         visible=gui_show_trajectories.value
+                     )
+
+     # Add trajectory controls functionality
+     @gui_show_trajectories.on_update
+     def _(_) -> None:
+         with server.atomic():
+             # Remove all existing trajectories
+             for i in range(num_frames):
+                 try:
+                     server.scene.remove_by_name(f"/frames/t{i}/head1/trajectory")
+                 except KeyError:
+                     pass
+
+             # Create new trajectories if enabled
+             if gui_show_trajectories.value and traj_3d_head1 is not None:
+                 # Get the mask for the last frame and reshape it
+                 last_frame_mask = masks[traj_end_frame-1].reshape(-1)
+
+                 # Calculate trajectory lengths
+                 trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame]
+                 traj_diffs = np.diff(trajectories, axis=0)
+                 traj_lengths = np.sum(np.sqrt(np.sum(traj_diffs**2, axis=-1)), axis=0)
+
+                 # Get points that are within the mask
+                 valid_indices = np.where(last_frame_mask)[0]
+
+                 if len(valid_indices) > 0:
+                     # Filter by trajectory length
+                     masked_traj_lengths = traj_lengths[valid_indices]
+                     avg_traj_length = np.mean(masked_traj_lengths)
+                     long_traj_indices = valid_indices[masked_traj_lengths >= avg_traj_length]
+
+                     # Randomly sample from the filtered points
+                     if len(long_traj_indices) > 0:
+                         # Random sampling without replacement
+                         selected_indices = np.random.choice(
+                             len(long_traj_indices),
+                             min(points_to_visualize, len(long_traj_indices)),
+                             replace=False
+                         )
+                         # Get the actual indices in their original order
+                         valid_point_indices = long_traj_indices[np.sort(selected_indices)]
+                     else:
+                         valid_point_indices = np.array([])
+                 else:
+                     valid_point_indices = np.array([])
+
+                 if len(valid_point_indices) > 0:
+                     # Get trajectories for all valid points
+                     trajectories = xyz_head1_centered[traj_start_frame:traj_end_frame, valid_point_indices]
+                     N_point = trajectories.shape[1]
+
+                     if color_code == "rainbow":
+                         point_colors = plt.cm.rainbow(np.linspace(0, 1, N_point))[:, :3]
+                     elif color_code == "jet":
+                         point_colors = plt.cm.jet(np.linspace(0, 1, N_point))[:, :3]
+
+                     # Modify the loop to handle frames less than fixed_length_traj
+                     for i in range(traj_end_frame - traj_start_frame):
+                         # Calculate the actual trajectory length for this frame
+                         actual_length = min(fixed_length_traj, i + 1)
+
+                         if actual_length > 1:  # Need at least 2 points to form a line
+                             # Get the appropriate slice of trajectory data
+                             start_idx = max(0, i - actual_length + 1)
+                             end_idx = i + 1
+
+                             # Create line segments between consecutive frames
+                             traj_slice = trajectories[start_idx:end_idx]
+                             line_points = np.stack([traj_slice[:-1], traj_slice[1:]], axis=2)
+                             line_points = line_points.reshape(-1, 2, 3)
+
+                             # Create corresponding colors
+                             line_colors = np.tile(point_colors, (actual_length-1, 1))
+                             line_colors = np.stack([line_colors, line_colors], axis=1)
+
+                             # Add line segments
+                             server.scene.add_line_segments(
+                                 name=f"/frames/t{i+traj_start_frame}/head1/trajectory",
+                                 points=line_points,
+                                 colors=line_colors,
+                                 line_width=traj_line_width,
+                                 visible=True
+                             )
+
+     # Update color tinting when the checkbox changes
+     @gui_use_color_tint.on_update
+     def _(_) -> None:
+         with server.atomic():
+             for i in range(num_frames):
+                 # Update head1 point cloud
+                 if traj_3d_head1 is not None:
+                     position = xyz_head1[i]
+                     color = rgb_head1[i]
+                     if conf_mask_head1 is not None:
+                         position = position[conf_mask_head1[i]]
+                         color = color[conf_mask_head1[i]]
+
+                     color_head1 = color.copy()
+                     if gui_use_color_tint.value:
+                         color_head1 *= blend_ratio
+                         color_head1[:, 0] = onp.clip(color_head1[:, 0] + blue_r * (1 - blend_ratio), 0, 1)  # R
+                         color_head1[:, 1] = onp.clip(color_head1[:, 1] + blue_g * (1 - blend_ratio), 0, 1)  # G
+                         color_head1[:, 2] = onp.clip(color_head1[:, 2] + blue_b * (1 - blend_ratio), 0, 1)  # B
+
+                     server.scene.remove_by_name(f"/frames/t{i}/head1/point_cloud")
+                     server.scene.add_point_cloud(
+                         name=f"/frames/t{i}/head1/point_cloud",
+                         points=position[::downsample_factor],
+                         colors=color_head1[::downsample_factor],
+                         point_size=point_size,
+                         point_shape="rounded",
+                     )
+
+                 # Update head2 point cloud
+                 if traj_3d_head2 is not None:
+                     position = xyz_head2[i]
+                     color = rgb_head2[i]
+                     if conf_mask_head2 is not None:
+                         position = position[conf_mask_head2[i]]
+                         color = color[conf_mask_head2[i]]
+
+                     color_head2 = color.copy()
+                     if gui_use_color_tint.value:
+                         color_head2 *= blend_ratio
+                         color_head2[:, 0] = onp.clip(color_head2[:, 0] + red_r * (1 - blend_ratio), 0, 1)  # R
+                         color_head2[:, 1] = onp.clip(color_head2[:, 1] + red_g * (1 - blend_ratio), 0, 1)  # G
+                         color_head2[:, 2] = onp.clip(color_head2[:, 2] + red_b * (1 - blend_ratio), 0, 1)  # B
+
+                     server.scene.remove_by_name(f"/frames/t{i}/head2/point_cloud")
+                     server.scene.add_point_cloud(
+                         name=f"/frames/t{i}/head2/point_cloud",
+                         points=position[::downsample_factor],
+                         colors=color_head2[::downsample_factor],
+                         point_size=point_size,
+                         point_shape="rounded",
+                     )
+
+     # Initialize video preview
+     if raw_video is not None:
+         video_preview.image = process_video_frame(0)
+
+     # Update video preview when timestep changes
+     @gui_timestep.on_update
+     def _(_) -> None:
+         current_timestep = gui_timestep.value
+         if raw_video is not None:
+             video_preview.image = process_video_frame(current_timestep)
+
+     # Playback update loop.
+     log_memory_usage("before starting playback loop")
+
+     prev_timestep = gui_timestep.value
+     while True:
+         current_timestep = gui_timestep.value
+
+         # If timestep changes, update frame visibility
+         if current_timestep != prev_timestep:
+             with server.atomic():
+                 # ... existing code ...
+
+                 # Update video preview
+                 if raw_video is not None:
+                     video_preview.image = process_video_frame(current_timestep)
+
+         # Update in playback mode
+         if gui_playing.value and not gui_show_all_frames.value:
+             gui_timestep.value = (gui_timestep.value + 1) % num_frames
+
+             # Update video preview in playback mode
+             if raw_video is not None:
+                 video_preview.image = process_video_frame(gui_timestep.value)
+
+         time.sleep(1.0 / gui_framerate.value)
+
+
+ if __name__ == "__main__":
+     tyro.cli(visualize_st4rtrack)
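
Since the entry point wraps `visualize_st4rtrack` in `tyro.cli`, the script can also be launched standalone; tyro exposes each keyword argument as a command-line flag. An invocation might look like the following (the `results` directory is a placeholder for wherever the `pts3d*_p*.npy` / `conf*_p*.npy` dumps live):

```bash
# Standalone visualization; tyro maps the function's keyword arguments to flags.
python vis_st4rtrack.py --traj-path results --max-frames 100 --point-size 0.005
```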
viser_proxy_manager.py ADDED
@@ -0,0 +1,223 @@
+ import asyncio
+
+ import httpx
+ import viser
+ import websockets
+ from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
+ from fastapi.responses import Response
+
+
+ class ViserProxyManager:
+     """Manages Viser server instances for Gradio applications.
+
+     This class handles the creation, retrieval, and cleanup of Viser server instances,
+     as well as proxying HTTP and WebSocket requests to the appropriate Viser server.
+
+     Args:
+         app: The FastAPI application to which the proxy routes will be added.
+         min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
+             These ports are used only for internal communication and don't need to be publicly exposed.
+         max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
+             These ports are used only for internal communication and don't need to be publicly exposed.
+         max_message_size: Maximum WebSocket message size in bytes. Defaults to 100MB.
+     """
+
+     def __init__(
+         self,
+         app: FastAPI,
+         min_local_port: int = 8000,
+         max_local_port: int = 9000,
+         max_message_size: int = 300 * 1024 * 1024,  # 300MB default
+     ) -> None:
+         self._min_port = min_local_port
+         self._max_port = max_local_port
+         self._max_message_size = max_message_size
+         self._server_from_session_hash: dict[str, viser.ViserServer] = {}
+         self._last_port = self._min_port - 1  # Track last port tried
+
+         @app.get("/viser/{server_id}/{proxy_path:path}")
+         async def proxy(request: Request, server_id: str, proxy_path: str):
+             """Proxy HTTP requests to the appropriate Viser server."""
+             # Get the local port for this server ID
+             server = self._server_from_session_hash.get(server_id)
+             if server is None:
+                 return Response(content="Server not found", status_code=404)
+
+             # Build target URL
+             if proxy_path:
+                 path_suffix = f"/{proxy_path}"
+             else:
+                 path_suffix = "/"
+
+             target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
+             if request.url.query:
+                 target_url += f"?{request.url.query}"
+
+             # Forward request
+             async with httpx.AsyncClient() as client:
+                 # Forward the original headers, but remove any problematic ones
+                 headers = dict(request.headers)
+                 headers.pop("host", None)  # Remove host header to avoid conflicts
+                 headers["accept-encoding"] = "identity"  # Disable compression
+
+                 proxied_req = client.build_request(
+                     method=request.method,
+                     url=target_url,
+                     headers=headers,
+                     content=await request.body(),
+                 )
+                 proxied_resp = await client.send(proxied_req, stream=True)
+
+                 # Get response headers
+                 response_headers = dict(proxied_resp.headers)
+
+                 # Check if this is an HTML response
+                 content = await proxied_resp.aread()
+                 return Response(
+                     content=content,
+                     status_code=proxied_resp.status_code,
+                     headers=response_headers,
+                 )
+
+         # WebSocket Proxy
+         @app.websocket("/viser/{server_id}")
+         async def websocket_proxy(websocket: WebSocket, server_id: str):
+             """Proxy WebSocket connections to the appropriate Viser server."""
+             try:
+                 await websocket.accept()
+
+                 server = self._server_from_session_hash.get(server_id)
+                 if server is None:
+                     await websocket.close(code=1008, reason="Not Found")
+                     return
+
+                 # Determine target WebSocket URL
+                 target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
+
+                 if not target_ws_url:
+                     await websocket.close(code=1008, reason="Not Found")
+                     return
+
+                 try:
+                     # Connect to the target WebSocket with increased message size and timeout
+                     async with websockets.connect(
+                         target_ws_url,
+                         max_size=self._max_message_size,
+                         ping_interval=30,  # Send ping every 30 seconds
+                         ping_timeout=10,  # Wait 10 seconds for pong response
+                         close_timeout=5,  # Wait 5 seconds for close handshake
+                     ) as ws_target:
+                         # Create tasks for bidirectional communication
+                         async def forward_to_target():
+                             """Forward messages from the client to the target WebSocket."""
+                             try:
+                                 while True:
+                                     data = await websocket.receive_bytes()
+                                     await ws_target.send(data, text=False)
+                             except WebSocketDisconnect:
+                                 try:
+                                     await ws_target.close()
+                                 except RuntimeError:
+                                     pass
+
+                         async def forward_from_target():
+                             """Forward messages from the target WebSocket to the client."""
+                             try:
+                                 while True:
+                                     data = await ws_target.recv(decode=False)
+                                     await websocket.send_bytes(data)
+                             except websockets.exceptions.ConnectionClosed:
+                                 try:
+                                     await websocket.close()
+                                 except RuntimeError:
+                                     pass
+
+                         # Run both forwarding tasks concurrently
+                         forward_task = asyncio.create_task(forward_to_target())
+                         backward_task = asyncio.create_task(forward_from_target())
+
+                         # Wait for either task to complete (which means a connection was closed)
+                         done, pending = await asyncio.wait(
+                             [forward_task, backward_task],
+                             return_when=asyncio.FIRST_COMPLETED,
+                         )
+
+                         # Cancel the remaining task
+                         for task in pending:
+                             task.cancel()
+
+                 except websockets.exceptions.ConnectionClosedError as e:
+                     print(f"WebSocket connection closed with error: {e}")
+                     await websocket.close(code=1011, reason="Connection to target closed")
+
+             except Exception as e:
+                 print(f"WebSocket proxy error: {e}")
+                 try:
+                     await websocket.close(code=1011, reason=str(e)[:120])  # Limit reason length
+                 except:
+                     pass  # Already closed
+
+     def start_server(self, server_id: str) -> viser.ViserServer:
+         """Start a new Viser server and associate it with the given server ID.
+
+         Finds an available port within the configured min_local_port and max_local_port range.
+         These ports are used only for internal communication and don't need to be publicly exposed.
+
+         Args:
+             server_id: The unique identifier to associate with the new server.
+
+         Returns:
+             The newly created Viser server instance.
+
+         Raises:
+             RuntimeError: If no free ports are available in the configured range.
+         """
+         import socket
+
+         # Start searching from the last port + 1 (with wraparound)
+         port_range_size = self._max_port - self._min_port + 1
+         start_port = (
+             (self._last_port + 1 - self._min_port) % port_range_size
+         ) + self._min_port
+
+         # Try each port once
+         for offset in range(port_range_size):
+             port = (
+                 (start_port - self._min_port + offset) % port_range_size
+             ) + self._min_port
+             try:
+                 # Check if port is available by attempting to bind to it
+                 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                     s.bind(("127.0.0.1", port))
+                 # Port is available, create server with this port
+                 server = viser.ViserServer(port=port)
+                 self._server_from_session_hash[server_id] = server
+                 self._last_port = port
+                 return server
+             except OSError:
+                 # Port is in use, try the next one
+                 continue
+
+         # If we get here, no ports were available
+         raise RuntimeError(
+             f"No available local ports in range {self._min_port}-{self._max_port}"
+         )
+
+     def get_server(self, server_id: str) -> viser.ViserServer:
+         """Retrieve a Viser server instance by its ID.
+
+         Args:
+             server_id: The unique identifier of the server to retrieve.
+
+         Returns:
+             The Viser server instance associated with the given ID.
+         """
+         return self._server_from_session_hash[server_id]
+
+     def stop_server(self, server_id: str) -> None:
+         """Stop a Viser server and remove it from the manager.
+
+         Args:
+             server_id: The unique identifier of the server to stop.
+         """
+         self._server_from_session_hash.pop(server_id).stop()
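
For reference, a minimal sketch of how this manager plugs into a FastAPI + Gradio host; app.py in this same commit is the full version, and only names defined in this commit are used below:

```python
# Minimal sketch: one Viser server per Gradio session, proxied under /viser/<session_hash>/.
import fastapi
import gradio as gr
import uvicorn

from viser_proxy_manager import ViserProxyManager

app = fastapi.FastAPI()
manager = ViserProxyManager(app)  # registers the /viser/{server_id} HTTP and WebSocket routes

with gr.Blocks() as demo:
    iframe_html = gr.HTML("")

    @demo.load(outputs=[iframe_html])
    def start(request: gr.Request):
        assert request.session_hash is not None
        manager.start_server(request.session_hash)  # picks a free local port for this session
        return f'<iframe src="/viser/{request.session_hash}/" width="100%" height="500px"></iframe>'

    @demo.unload
    def stop(request: gr.Request):
        assert request.session_hash is not None
        manager.stop_server(request.session_hash)

gr.mount_gradio_app(app, demo, "/")
uvicorn.run(app, host="0.0.0.0", port=7860)
```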