Update visualizer.py
visualizer.py  CHANGED  +495 -332

@@ -4,390 +4,553 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

 import os
 import imageio
 import torch

 from matplotlib import cm
-import torch.nn.functional as F
-import torchvision.transforms as transforms
-import matplotlib.pyplot as plt
 from PIL import Image, ImageDraw


-def ...
-    draw.ellipse(
-        [left_up_point, right_down_point],
-        fill=tuple(color) if visible else None,
-        outline=tuple(color),
-    )
-    return rgb


-def draw_line(rgb, coord_y, coord_x, color, linewidth):
-    draw = ImageDraw.Draw(rgb)
-    draw.line(
-        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
-        fill=tuple(color),
-        width=linewidth,
-    )
-    return rgb


-def ...
-    return (rgb * alpha + original * beta + gamma).astype("uint8")


-class Visualizer:
     def __init__(
         self,
-        ...
-        pad_value: int = 0,
         fps: int = 10,
-        ...
     ):
-        ...
-        if mode == "rainbow":
-            self.color_map = cm.get_cmap("gist_rainbow")
-        elif mode == "cool":
-            self.color_map = cm.get_cmap(mode)
-        self.show_first_frame = show_first_frame
-        self.grayscale = grayscale
-        self.tracks_leave_trace = tracks_leave_trace
-        self.pad_value = pad_value
-        self.linewidth = linewidth
         self.fps = fps

     def visualize(
         self,
         video: torch.Tensor,  # (B,T,C,H,W)
         tracks: torch.Tensor,  # (B,T,N,2)
-        visibility: torch.Tensor = None,  # (B,
-        segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
-        step: int = 0,
-        query_frame=0,
         save_video: bool = True,
-        ...
             video,
-            (self.
-            "constant",
-            255,
         )
-        color_alpha = int(opacity * 255)
-        tracks = tracks + self.pad_value

-            video = transform(video)
-            video = video.repeat(1, 1, 3, 1, 1)

-        ...
             visibility=visibility,
-            gt_tracks=gt_tracks,
             query_frame=query_frame,
-            color_alpha=color_alpha,
         )
         if save_video:
-            self.save_video(
-        ...

-    def save_video(self, video, filename, writer=None, step=0):
-        if writer is not None:
-            writer.add_video(
-                filename,
-                video.to(torch.uint8),
-                global_step=step,
-                fps=self.fps,
-            )
-        else:
-            os.makedirs(self.save_dir, exist_ok=True)
-            wide_list = list(video.unbind(1))
-            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
-        ...

     def draw_tracks_on_video(
         self,
         video: torch.Tensor,
         tracks: torch.Tensor,
         visibility: torch.Tensor = None,
-        ...
         )
-        ...
-                vector_colors[:, n] = np.repeat(color, T, axis=0)
-        else:
-            # color changes with segm class
-            segm_mask = segm_mask.cpu()
-            color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
-            color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
-            color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
-            vector_colors = np.repeat(color[None], T, axis=0)

-        # draw tracks
-        if self.tracks_leave_trace != 0:
-            for t in range(query_frame + 1, T):
-                first_ind = (
-                    max(0, t - self.tracks_leave_trace)
-                    if self.tracks_leave_trace >= 0
-                    else 0
                 )
-                ...
-                    curr_tracks = curr_tracks[:, segm_mask > 0]
-                    curr_colors = curr_colors[:, segm_mask > 0]

-                res_video[t] = self._draw_pred_tracks(
-                    res_video[t],
-                    curr_tracks,
-                    curr_colors,
                 )
-        ...

-        for t in range(T):
-            img = Image.fromarray(np.uint8(res_video[t]))
-            for i in range(N):
-                coord = (tracks[t, i, 0], tracks[t, i, 1])
-                visibile = True
-                if visibility is not None:
-                    visibile = visibility[0, t, i]
-                if coord[0] != 0 and coord[1] != 0:
-                    if not compensate_for_camera_motion or (
-                        compensate_for_camera_motion and segm_mask[i] > 0
-                    ):
-                        # img = draw_circle(
-                        #     img,
-                        #     coord=coord,
-                        #     radius=int(self.linewidth * 2),
-                        #     color=vector_colors[t, i].astype(int),
-                        #     visible=visibile,
-                        #     color_alpha=color_alpha,
-                        # )

-                        # coord_ = coord[t,i]
-                        # draw a red cross
-                        # if gt_tracks[0] > 0 and gt_tracks[1] > 0:
-                        if visibile:
-                            length = self.linewidth * 3
-                            coord_y = (int(coord[0]) + length, int(coord[1]) + length)
-                            coord_x = (int(coord[0]) - length, int(coord[1]) - length)
-                            rgb = draw_line(
-                                img,
-                                coord_y,
-                                coord_x,
-                                vector_colors[t, i].astype(int),
-                                self.linewidth,
-                            )
-                            coord_y = (int(coord[0]) - length, int(coord[1]) + length)
-                            coord_x = (int(coord[0]) + length, int(coord[1]) - length)
-                            rgb = draw_line(
-                                img,
-                                coord_y,
-                                coord_x,
-                                vector_colors[t, i].astype(int),
-                                self.linewidth,
-                            )
-            res_video[t] = np.array(img)

-        # construct the final rgb sequence
-        if self.show_first_frame > 0:
-            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
-        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

-    def _draw_pred_tracks(
         self,
-        ...

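Most constructor and visualize() arguments are renamed by this commit. The mapping below is inferred by matching the removed fragments above against the new file that follows; the old signatures are only partially visible in this hunk, so treat it as a best-effort guide rather than an authoritative migration table.

# Inferred old -> new argument names (best effort; old signatures are partially lost above).
RENAMED_ARGS = {
    # Visualizer.__init__
    "save_dir": "output_dir",
    "pad_value": "padding",
    "linewidth": "line_width",
    "mode": "colormap",                    # the "cool" colormap is replaced by "spring"
    "show_first_frame": "initial_frame_repeat",
    "tracks_leave_trace": "track_history_length",
    # Visualizer.visualize
    "segm_mask": "segmentation",
    "opacity": "point_opacity",
}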
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+
 import os
+from typing import List
+
 import imageio
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
+import torch.nn.functional as F

 from matplotlib import cm
 from PIL import Image, ImageDraw


+def draw_circle_on_image(
+    image: Image,
+    center: tuple,
+    radius: int,
+    color: tuple = (255, 0, 0),
+    visible: bool = True,
+    alpha: int = None,
+) -> Image:
+    """Draw a circle on a PIL Image.
+
+    Args:
+        image: PIL Image to draw on
+        center: (x,y) coordinates of circle center
+        radius: Radius of circle in pixels
+        color: RGB color tuple
+        visible: Whether to fill the circle
+        alpha: Optional alpha value for transparency
+
+    Returns:
+        Modified PIL Image
+    """
+    draw = ImageDraw.Draw(image, 'RGBA')  # Enable alpha channel
+    # Use float coordinates for smoother rendering
+    bbox = [
+        (center[0] - radius, center[1] - radius),
+        (center[0] + radius, center[1] + radius),
+    ]
+    color = tuple(list(color) + [alpha if alpha is not None else 255])
+
+    # Use anti-aliasing by drawing a slightly larger circle underneath
+    if visible:
+        # Draw a slightly larger background circle for anti-aliasing
+        larger_bbox = [
+            (center[0] - radius - 0.5, center[1] - radius - 0.5),
+            (center[0] + radius + 0.5, center[1] + radius + 0.5),
+        ]
+        draw.ellipse(larger_bbox, fill=tuple(list(color[:-1]) + [int(color[-1] * 0.5)]))
+
+    draw.ellipse(bbox, fill=tuple(color) if visible else None, outline=tuple(color))
+    return image
+
+
+def draw_line_segment(
+    image: Image, start: tuple, end: tuple, color: tuple, width: int
+) -> Image:
+    """Draw a line on a PIL Image.
+
+    Args:
+        image: PIL Image to draw on
+        start: (x,y) coordinates of line start
+        end: (x,y) coordinates of line end
+        color: RGB color tuple
+        width: Line width in pixels
+
+    Returns:
+        Modified PIL Image
+    """
+    draw = ImageDraw.Draw(image)
+    draw.line((start[0], start[1], end[0], end[1]), fill=tuple(color), width=width)
+    return image
+
+
+def blend_images(
+    image1: np.ndarray, alpha: float, image2: np.ndarray, beta: float, gamma: float
+) -> np.ndarray:
+    """Blend two images with weights.
+
+    Args:
+        image1: First image array
+        alpha: Weight of first image
+        image2: Second image array
+        beta: Weight of second image
+        gamma: Scalar added to weighted sum
+
+    Returns:
+        Blended uint8 image array
+    """
+    return (image1 * alpha + image2 * beta + gamma).astype("uint8")


+class Visualizer:
+    """A class for visualizing point tracks on videos.

+    Handles drawing tracked points and their trajectories on video frames.

+    Args:
+        output_dir: Directory to save output visualizations
+        padding: Padding to add around video frames in pixels
+        fps: Frames per second for output video
+        colormap: Color scheme for tracks ('rainbow' or 'spring')
+        line_width: Width of track lines in pixels
+        initial_frame_repeat: Number of times to repeat first frame
+        track_history_length: How many past frames to show tracks for (0=current only, -1=all)
+    """

     def __init__(
         self,
+        output_dir: str = "./results",
+        padding: int = 0,
         fps: int = 10,
+        colormap: str = "rainbow",
+        line_width: int = 2,
+        initial_frame_repeat: int = 10,
+        track_history_length: int = 0,
     ):
+        self.output_dir = output_dir
+        self.padding = padding
         self.fps = fps
+        self.line_width = line_width
+        self.initial_frame_repeat = initial_frame_repeat
+        self.track_history_length = track_history_length
+
+        # Set up colormap for track visualization
+        self.colormap = colormap
+        if colormap not in ["rainbow", "spring"]:
+            raise ValueError("Colormap must be 'rainbow' or 'spring'")
+
+        self.color_mapper = cm.get_cmap(
+            "gist_rainbow" if colormap == "rainbow" else "spring"
+        )

     def visualize(
         self,
         video: torch.Tensor,  # (B,T,C,H,W)
         tracks: torch.Tensor,  # (B,T,N,2)
+        visibility: torch.Tensor = None,  # (B,T,N,1) bool
+        segmentation: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
+        query_frame: int = 0,
         save_video: bool = True,
+        point_opacity: float = 1.0,
+    ) -> torch.Tensor:
+        """Visualize tracked points and their trajectories on video frames.
+
+        Args:
+            video: Input video tensor of shape (B,T,C,H,W)
+            tracks: Point track coordinates of shape (B,T,N,2)
+            visibility: Optional visibility mask of shape (B,T,N,1)
+            segmentation: Optional segmentation mask of shape (B,1,H,W)
+            filename: Output filename for saved video
+            query_frame: Frame index to use for color assignment
+            save_video: Whether to save visualization video
+            point_opacity: Opacity value for track points (0-1)
+
+        Returns:
+            Tensor containing visualization frames
+        """
+        # Process segmentation if provided
+        if segmentation is not None:
+            coords = tracks[0, query_frame].round().long()
+            segmentation = segmentation[0, query_frame][
+                coords[:, 1], coords[:, 0]
+            ].long()
+
+        # Add padding to video frames
+        padded_video = F.pad(
             video,
+            (self.padding, self.padding, self.padding, self.padding),
+            mode="constant",
+            value=255,
         )

+        # Convert opacity to integer value
+        opacity_value = min(max(int(point_opacity * 255), 0), 255)

+        # Adjust track coordinates for padding
+        padded_tracks = tracks + self.padding
+
+        # Generate visualization frames
+        output_video = self.draw_tracks_on_video(
+            video=padded_video,
+            tracks=padded_tracks,
             visibility=visibility,
+            segmentation=segmentation,
             query_frame=query_frame,
+            opacity=opacity_value,
         )
+
+        # Save video if requested
         if save_video:
+            self.save_video(output_video, filename=filename)
+
+        return output_video

+    def save_video(self, video: torch.Tensor, filename: str):
+        """Save video tensor as MP4 file.

+        Args:
+            video: Video tensor of shape (B,T,C,H,W)
+            filename: Output filename without extension
+        """
+        os.makedirs(self.output_dir, exist_ok=True)

+        # Extract frames from video tensor
+        frames = [
+            frame[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
+            for frame in video.unbind(1)
+        ]

+        output_path = os.path.join(self.output_dir, f"{filename}.mp4")

+        try:
+            with imageio.get_writer(output_path, fps=self.fps, quality=8) as writer:
+                # Write frames excluding padding frames
+                for frame in frames[2:-1]:
+                    writer.append_data(frame)
+
+            print(f"Successfully saved video to {output_path}")
+
+        except Exception as e:
+            print(f"Error saving video to {output_path}: {str(e)}")

     def draw_tracks_on_video(
         self,
         video: torch.Tensor,
         tracks: torch.Tensor,
         visibility: torch.Tensor = None,
+        segmentation: torch.Tensor = None,
+        query_frame: int = 0,
+        opacity: int = 255,
+    ) -> torch.Tensor:
+        """Draw tracks on video frames.
+
+        Args:
+            video: Video tensor of shape (B,T,C,H,W)
+            tracks: Track coordinates tensor of shape (B,T,N,2)
+            visibility: Optional visibility mask of shape (B,T,N)
+            segmentation: Optional segmentation mask for coloring
+            query_frame: Frame index to use for rainbow coloring
+            opacity: Opacity value for track points (0-255)
+
+        Returns:
+            Video tensor with visualized tracks of shape (1,T,3,H,W)
+        """
+        # Validate input dimensions
+        _, num_frames, channels, _, _ = video.shape
+        _, _, num_points, dims = tracks.shape
+        assert dims == 2 and channels == 3, "Invalid input dimensions"
+
+        # Convert tensors to numpy arrays but keep as float
+        video_np = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()
+        tracks_np = tracks[0].detach().cpu().numpy()
+
+        # Create output frame buffer
+        output_frames = [frame.copy() for frame in video_np]
+
+        # Assign colors to tracks based on segmentation or position
+        track_colors = (
+            self._assign_segmentation_colors(
+                tracks_np, segmentation, num_frames, num_points
+            )
+            if segmentation is not None
+            else self._assign_track_colors(
+                tracks_np, query_frame, num_frames, num_points
+            )
+        )
+
+        # Draw track history lines if enabled
+        if self.track_history_length != 0:
+            output_frames = self._draw_track_lines(
+                output_frames, tracks_np, track_colors, query_frame, num_frames
+            )
+
+        # Draw track points with visibility and opacity
+        output_frames = self._draw_track_points(
+            output_frames, tracks_np, track_colors, visibility, opacity
+        )

+        # Add initial frame repeats for better visualization
+        if self.initial_frame_repeat > 0:
+            output_frames = [
+                output_frames[0]
+            ] * self.initial_frame_repeat + output_frames[1:]
+
+        # Convert back to torch tensor
+        return (
+            torch.from_numpy(np.stack(output_frames)).permute(0, 3, 1, 2)[None].byte()
+        )
+
+    def _assign_track_colors(
+        self, tracks: np.ndarray, query_frame: int, num_frames: int, num_points: int
+    ) -> np.ndarray:
+        """Assigns colors to tracks based on either rainbow mapping of y-coordinates or time-based coloring.
+
+        Args:
+            tracks: Track coordinates array of shape (num_frames, num_points, 2)
+            query_frame: Frame index to use for rainbow coloring
+            num_frames: Total number of frames
+            num_points: Number of tracked points
+
+        Returns:
+            Array of track colors with shape (num_frames, num_points, 3)
+        """
+        track_colors = np.zeros((num_frames, num_points, 3))
+
+        if self.colormap == "rainbow":
+            # Normalize y-coordinates to [0,1] range for rainbow coloring
+            y_coords = tracks[query_frame, :, 1]
+            y_min, y_max = y_coords.min(), y_coords.max()
+            if y_min == y_max:
+                y_max = y_min + 1  # Avoid division by zero
+            norm = plt.Normalize(y_min, y_max)
+
+            # Assign colors based on normalized y-coordinate
+            for point_idx in range(num_points):
+                query_idx = (
+                    query_frame[point_idx]
+                    if isinstance(query_frame, torch.Tensor)
+                    else query_frame
                 )
+                color = (
+                    np.array(
+                        self.color_mapper(norm(tracks[query_idx, point_idx, 1]))[:3]
+                    )[None]
+                    * 255
                 )
+                track_colors[:, point_idx] = np.repeat(color, num_frames, axis=0)
+        else:
+            # Assign colors that vary smoothly with time
+            for frame_idx in range(num_frames):
+                color = (
+                    np.array(self.color_mapper(frame_idx / max(1, num_frames - 1))[:3])[
+                        None
+                    ]
+                    * 255
                 )
+                track_colors[frame_idx] = np.repeat(color, num_points, axis=0)
+
+        return track_colors.astype(np.uint8)

+    def _assign_segmentation_colors(
         self,
+        tracks: np.ndarray,
+        segmentation: torch.Tensor,
+        num_frames: int,
+        num_points: int,
+    ) -> np.ndarray:
+        """Assigns colors to tracks based on segmentation masks and colormap.
+
+        Args:
+            tracks: Track coordinates array of shape (num_frames, num_points, 2)
+            segmentation: Binary segmentation mask of shape (num_points,)
+            num_frames: Total number of frames
+            num_points: Number of tracked points
+
+        Returns:
+            Array of track colors with shape (num_frames, num_points, 3)
+        """
+        track_colors = np.zeros((num_frames, num_points, 3))
+
+        if self.colormap == "rainbow":
+            # Set background points to white
+            background_mask = segmentation <= 0
+            track_colors[:, background_mask, :] = 255
+
+            # Color foreground points based on y-coordinate
+            foreground_mask = segmentation > 0
+            if torch.any(foreground_mask):
+                y_coords = tracks[0, foreground_mask, 1]
+                y_min, y_max = y_coords.min(), y_coords.max()
+                if y_min == y_max:
+                    y_max = y_min + 1  # Avoid division by zero
+                norm = plt.Normalize(y_min, y_max)
+
+                for point_idx in range(num_points):
+                    if segmentation[point_idx] > 0:
+                        color = (
+                            np.array(
+                                self.color_mapper(norm(tracks[0, point_idx, 1]))[:3]
+                            )[None]
+                            * 255
                         )
+                        track_colors[:, point_idx] = np.repeat(
+                            color, num_frames, axis=0
+                        )
+        else:
+            # Binary coloring based on segmentation
+            segmentation = segmentation.cpu()
+            colors = np.zeros((num_points, 3), dtype=np.float32)
+            colors[segmentation > 0] = (
+                np.array(self.color_mapper(1.0)[:3]) * 255.0
+            )  # Foreground
+            colors[segmentation <= 0] = (
+                np.array(self.color_mapper(0.0)[:3]) * 255.0
+            )  # Background
+            track_colors = np.repeat(colors[None], num_frames, axis=0)
+
+        return track_colors.astype(np.uint8)
+
+    def _draw_track_lines(
+        self,
+        frames: List[np.ndarray],
+        tracks: np.ndarray,
+        track_colors: np.ndarray,
+        query_frame: int,
+        num_frames: int,
+    ) -> List[np.ndarray]:
+        """Draw track lines showing point trajectories over time.
+
+        Args:
+            frames: List of video frames to draw on
+            tracks: Array of track coordinates (num_frames, num_points, 2)
+            track_colors: Array of track colors (num_frames, num_points, 3)
+            query_frame: Frame index where tracking starts
+            num_frames: Total number of frames
+
+        Returns:
+            List of frames with track lines drawn
+        """
+        # Draw tracks starting from query frame
+        for frame_idx in range(query_frame + 1, num_frames):
+            # Get track history based on history length setting
+            start_idx = (
+                max(0, frame_idx - self.track_history_length)
+                if self.track_history_length >= 0
+                else 0
+            )

+            # Extract relevant track segments and colors
+            curr_tracks = tracks[start_idx : frame_idx + 1]
+            curr_colors = track_colors[start_idx : frame_idx + 1]
+
+            # Draw track segments on current frame
+            frames[frame_idx] = self._draw_track_segments(
+                frames[frame_idx], curr_tracks, curr_colors
+            )
+
+        return frames
+
+    def _draw_track_segments(
         self,
+        frame: np.ndarray,
+        tracks: np.ndarray,
+        colors: np.ndarray,
+    ) -> np.ndarray:
+        """Draw track segments showing point trajectories between consecutive frames.
+
+        Args:
+            frame: Video frame to draw on
+            tracks: Array of track coordinates (num_segments, num_points, 2)
+            colors: Array of track colors (num_segments, num_points, 3)
+
+        Returns:
+            Frame with track segments drawn
+        """
+        num_segments, num_points, _ = tracks.shape
+        frame_img = Image.fromarray(np.uint8(frame))
+
+        for segment_idx in range(num_segments - 1):
+            segment_color = colors[segment_idx]
+            original = frame_img.copy()
+
+            # Use cubic falloff for track history opacity
+            alpha = (segment_idx / num_segments) ** 3
+
+            valid_points = ~np.isclose(tracks[segment_idx], 0).all(axis=1)
+
+            for point_idx in range(num_points):
+                if valid_points[point_idx]:
+                    start = (
+                        tracks[segment_idx, point_idx, 0],
+                        tracks[segment_idx, point_idx, 1],
                     )
+                    end = (
+                        tracks[segment_idx + 1, point_idx, 0],
+                        tracks[segment_idx + 1, point_idx, 1],
+                    )
+
+                    frame_img = draw_line_segment(
+                        frame_img,
+                        start,
+                        end,
+                        segment_color[point_idx].astype(int),
+                        self.line_width,
                    )
+
+            if self.track_history_length > 0:
+                frame_img = Image.fromarray(
+                    blend_images(
+                        np.array(frame_img), alpha, np.array(original), 1 - alpha, 0
+                    )
+                )
+
+        return np.array(frame_img)
+
+    def _draw_track_points(
+        self,
+        frames: List[np.ndarray],
+        tracks: np.ndarray,
+        track_colors: np.ndarray,
+        visibility: torch.Tensor,
+        opacity: int,
+    ) -> List[np.ndarray]:
+        """Draw tracked points on each frame with circles.
+
+        Args:
+            frames: List of video frames to draw on
+            tracks: Array of track coordinates (num_frames, num_points, 2)
+            track_colors: Array of track colors (num_frames, num_points, 3)
+            visibility: Tensor indicating point visibility per frame
+            opacity: Opacity value for drawing points
+
+        Returns:
+            List of frames with track points drawn
+        """
+        frame_imgs = [Image.fromarray(np.uint8(frame)) for frame in frames]
+
+        # Use more precise validation of points
+        valid_points = ~np.isclose(tracks, 0).all(axis=2)
+
+        for frame_idx, frame_img in enumerate(frame_imgs):
+            frame_visibility = (
+                np.ones(tracks.shape[1], dtype=bool)
+                if visibility is None
+                else visibility[0, frame_idx].cpu().numpy()
+            )
+
+            points_to_draw = np.logical_and(valid_points[frame_idx], frame_visibility)
+
+            for point_idx in np.where(points_to_draw)[0]:
+                # Keep coordinates as floats
+                coord = tuple(tracks[frame_idx, point_idx])
+                color = track_colors[frame_idx, point_idx].astype(int)
+
+                frame_img = draw_circle_on_image(
+                    frame_img,
+                    center=coord,
+                    radius=int(self.line_width * 2),
+                    color=color,
+                    visible=frame_visibility[point_idx],
+                    alpha=opacity,
+                )
+
+            frames[frame_idx] = np.array(frame_img)
+
+        return frames
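
For reference, here is a minimal sketch of how the rewritten class might be driven end to end. It is not part of the commit: the tensor shapes follow the docstrings above, the random inputs and the `from visualizer import Visualizer` module path are placeholders, and actually writing the MP4 additionally requires imageio's ffmpeg backend.

# Minimal usage sketch (not part of this commit); shapes follow the docstrings above.
import torch

from visualizer import Visualizer  # assumed module path

B, T, N, H, W = 1, 8, 16, 64, 96
video = torch.randint(0, 256, (B, T, 3, H, W), dtype=torch.uint8)   # (B,T,C,H,W)
tracks = torch.rand(B, T, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # (B,T,N,2), x/y in pixels

vis = Visualizer(
    output_dir="./results",        # where save_video() writes <filename>.mp4
    colormap="rainbow",            # or "spring"
    line_width=2,
    track_history_length=5,        # draw the last 5 frames of each trajectory
)

frames = vis.visualize(
    video=video,
    tracks=tracks,
    filename="demo",
    save_video=False,              # set True to write ./results/demo.mp4 (needs imageio-ffmpeg)
)
print(frames.shape)  # (1, T + initial_frame_repeat - 1, 3, H, W)

With the default initial_frame_repeat of 10, the first visualized frame is repeated, so the returned tensor holds T + 9 frames; save_video() then drops the first two and the last frame before writing.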