nikkar committed on
Commit
9f03168
·
verified ·
1 Parent(s): 8a1b8f9

Update visualizer.py

Browse files
Files changed (1) hide show
  1. visualizer.py +495 -332
visualizer.py CHANGED
@@ -4,390 +4,553 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
7
  import os
8
- import numpy as np
 
9
  import imageio
 
 
10
  import torch
 
11
 
12
  from matplotlib import cm
13
- import torch.nn.functional as F
14
- import torchvision.transforms as transforms
15
- import matplotlib.pyplot as plt
16
  from PIL import Image, ImageDraw
17
 
18
 
19
- def read_video_from_path(path):
20
- try:
21
- reader = imageio.get_reader(path)
22
- except Exception as e:
23
- print("Error opening video file: ", e)
24
- return None
25
- frames = []
26
- for i, im in enumerate(reader):
27
- frames.append(np.array(im))
28
- return np.stack(frames)
29
-
30
-
31
- def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True, color_alpha=None):
32
- # Create a draw object
33
- draw = ImageDraw.Draw(rgb)
34
- # Calculate the bounding box of the circle
35
- left_up_point = (coord[0] - radius, coord[1] - radius)
36
- right_down_point = (coord[0] + radius, coord[1] + radius)
37
- # Draw the circle
38
- color = tuple(list(color) + [color_alpha if color_alpha is not None else 255])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- draw.ellipse(
41
- [left_up_point, right_down_point],
42
- fill=tuple(color) if visible else None,
43
- outline=tuple(color),
44
- )
45
- return rgb
46
-
47
-
48
- def draw_line(rgb, coord_y, coord_x, color, linewidth):
49
- draw = ImageDraw.Draw(rgb)
50
- draw.line(
51
- (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
52
- fill=tuple(color),
53
- width=linewidth,
54
- )
55
- return rgb
56
 
 
 
57
 
58
- def add_weighted(rgb, alpha, original, beta, gamma):
59
- return (rgb * alpha + original * beta + gamma).astype("uint8")
60
 
 
 
 
 
 
 
 
 
 
61
 
62
- class Visualizer:
63
  def __init__(
64
  self,
65
- save_dir: str = "./results",
66
- grayscale: bool = False,
67
- pad_value: int = 0,
68
  fps: int = 10,
69
- mode: str = "rainbow", # 'cool', 'optical_flow'
70
- linewidth: int = 2,
71
- show_first_frame: int = 10,
72
- tracks_leave_trace: int = 0, # -1 for infinite
73
  ):
74
- self.mode = mode
75
- self.save_dir = save_dir
76
- if mode == "rainbow":
77
- self.color_map = cm.get_cmap("gist_rainbow")
78
- elif mode == "cool":
79
- self.color_map = cm.get_cmap(mode)
80
- self.show_first_frame = show_first_frame
81
- self.grayscale = grayscale
82
- self.tracks_leave_trace = tracks_leave_trace
83
- self.pad_value = pad_value
84
- self.linewidth = linewidth
85
  self.fps = fps
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  def visualize(
88
  self,
89
  video: torch.Tensor, # (B,T,C,H,W)
90
  tracks: torch.Tensor, # (B,T,N,2)
91
- visibility: torch.Tensor = None, # (B, T, N, 1) bool
92
- gt_tracks: torch.Tensor = None, # (B,T,N,2)
93
- segm_mask: torch.Tensor = None, # (B,1,H,W)
94
  filename: str = "video",
95
- writer=None, # tensorboard Summary Writer, used for visualization during training
96
- step: int = 0,
97
- query_frame=0,
98
  save_video: bool = True,
99
- compensate_for_camera_motion: bool = False,
100
- opacity: float = 1.0,
101
- ):
102
- if compensate_for_camera_motion:
103
- assert segm_mask is not None
104
- # if segm_mask is not None:
105
- # coords = tracks[0, query_frame].round().long()
106
- # segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
107
-
108
- video = F.pad(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  video,
110
- (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
111
- "constant",
112
- 255,
113
  )
114
- color_alpha = int(opacity * 255)
115
- tracks = tracks + self.pad_value
116
 
117
- if self.grayscale:
118
- transform = transforms.Grayscale()
119
- video = transform(video)
120
- video = video.repeat(1, 1, 3, 1, 1)
121
 
122
- res_video = self.draw_tracks_on_video(
123
- video=video,
124
- tracks=tracks,
 
 
 
 
125
  visibility=visibility,
126
- segm_mask=segm_mask,
127
- gt_tracks=gt_tracks,
128
  query_frame=query_frame,
129
- compensate_for_camera_motion=compensate_for_camera_motion,
130
- color_alpha=color_alpha,
131
  )
 
 
132
  if save_video:
133
- self.save_video(res_video, filename=filename, writer=writer, step=step)
134
- return res_video
135
-
136
- def save_video(self, video, filename, writer=None, step=0):
137
- if writer is not None:
138
- writer.add_video(
139
- filename,
140
- video.to(torch.uint8),
141
- global_step=step,
142
- fps=self.fps,
143
- )
144
- else:
145
- os.makedirs(self.save_dir, exist_ok=True)
146
- wide_list = list(video.unbind(1))
147
- wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
148
 
149
- # Prepare the video file path
150
- save_path = os.path.join(self.save_dir, f"{filename}.mp4")
151
 
152
- # Create a writer object
153
- video_writer = imageio.get_writer(save_path, fps=self.fps)
 
 
 
154
 
155
- # Write frames to the video file
156
- for frame in wide_list[2:-1]:
157
- video_writer.append_data(frame)
 
 
158
 
159
- video_writer.close()
160
 
161
- print(f"Video saved to {save_path}")
 
 
 
 
 
 
 
 
 
162
 
163
  def draw_tracks_on_video(
164
  self,
165
  video: torch.Tensor,
166
  tracks: torch.Tensor,
167
  visibility: torch.Tensor = None,
168
- segm_mask: torch.Tensor = None,
169
- gt_tracks=None,
170
- query_frame=0,
171
- compensate_for_camera_motion=False,
172
- color_alpha: int = 255,
173
- ):
174
- B, T, C, H, W = video.shape
175
- _, _, N, D = tracks.shape
176
-
177
- assert D == 2
178
- assert C == 3
179
- video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
180
- tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
181
- if gt_tracks is not None:
182
- gt_tracks = gt_tracks[0].detach().cpu().numpy()
183
-
184
- res_video = []
185
-
186
- # process input video
187
- for rgb in video:
188
- res_video.append(rgb.copy())
189
- vector_colors = np.zeros((T, N, 3))
190
-
191
- if self.mode == "optical_flow":
192
- import flow_vis
193
-
194
- vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
195
- elif segm_mask is None:
196
- if self.mode == "rainbow":
197
- y_min, y_max = (
198
- tracks[query_frame, :, 1].min(),
199
- tracks[query_frame, :, 1].max(),
200
- )
201
- norm = plt.Normalize(y_min, y_max)
202
- for n in range(N):
203
- if isinstance(query_frame, torch.Tensor):
204
- query_frame_ = query_frame[n]
205
- else:
206
- query_frame_ = query_frame
207
- color = self.color_map(norm(tracks[query_frame_, n, 1]))
208
- color = np.array(color[:3])[None] * 255
209
- vector_colors[:, n] = np.repeat(color, T, axis=0)
210
- else:
211
- # color changes with time
212
- for t in range(T):
213
- color = np.array(self.color_map(t / T)[:3])[None] * 255
214
- vector_colors[t] = np.repeat(color, N, axis=0)
215
- else:
216
- if self.mode == "rainbow":
217
- vector_colors[:, segm_mask <= 0, :] = 255
218
 
219
- y_min, y_max = (
220
- tracks[0, segm_mask > 0, 1].min(),
221
- tracks[0, segm_mask > 0, 1].max(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  )
223
- norm = plt.Normalize(y_min, y_max)
224
- for n in range(N):
225
- if segm_mask[n] > 0:
226
- color = self.color_map(norm(tracks[0, n, 1]))
227
- color = np.array(color[:3])[None] * 255
228
- vector_colors[:, n] = np.repeat(color, T, axis=0)
229
-
230
- else:
231
- # color changes with segm class
232
- segm_mask = segm_mask.cpu()
233
- color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
234
- color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
235
- color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
236
- vector_colors = np.repeat(color[None], T, axis=0)
237
-
238
- # draw tracks
239
- if self.tracks_leave_trace != 0:
240
- for t in range(query_frame + 1, T):
241
- first_ind = (
242
- max(0, t - self.tracks_leave_trace)
243
- if self.tracks_leave_trace >= 0
244
- else 0
245
  )
246
- curr_tracks = tracks[first_ind : t + 1]
247
- curr_colors = vector_colors[first_ind : t + 1]
248
- if compensate_for_camera_motion:
249
- diff = (
250
- tracks[first_ind : t + 1, segm_mask <= 0]
251
- - tracks[t : t + 1, segm_mask <= 0]
252
- ).mean(1)[:, None]
253
-
254
- curr_tracks = curr_tracks - diff
255
- curr_tracks = curr_tracks[:, segm_mask > 0]
256
- curr_colors = curr_colors[:, segm_mask > 0]
257
-
258
- res_video[t] = self._draw_pred_tracks(
259
- res_video[t],
260
- curr_tracks,
261
- curr_colors,
262
  )
263
- if gt_tracks is not None:
264
- res_video[t] = self._draw_gt_tracks(
265
- res_video[t], gt_tracks[first_ind : t + 1]
266
- )
267
 
268
- # draw points
269
- for t in range(T):
270
- img = Image.fromarray(np.uint8(res_video[t]))
271
- for i in range(N):
272
- coord = (tracks[t, i, 0], tracks[t, i, 1])
273
- visibile = True
274
- if visibility is not None:
275
- visibile = visibility[0, t, i]
276
- if coord[0] != 0 and coord[1] != 0:
277
- if not compensate_for_camera_motion or (
278
- compensate_for_camera_motion and segm_mask[i] > 0
279
- ):
280
- # img = draw_circle(
281
- # img,
282
- # coord=coord,
283
- # radius=int(self.linewidth * 2),
284
- # color=vector_colors[t, i].astype(int),
285
- # visible=visibile,
286
- # color_alpha=color_alpha,
287
- # )
288
-
289
- # coord_ = coord[t,i]
290
- # draw a red cross
291
- # if gt_tracks[0] > 0 and gt_tracks[1] > 0:
292
- if visibile:
293
- length = self.linewidth * 3
294
- coord_y = (int(coord[0]) + length, int(coord[1]) + length)
295
- coord_x = (int(coord[0]) - length, int(coord[1]) - length)
296
- rgb = draw_line(
297
- img,
298
- coord_y,
299
- coord_x,
300
- vector_colors[t, i].astype(int),
301
- self.linewidth,
302
- )
303
- coord_y = (int(coord[0]) - length, int(coord[1]) + length)
304
- coord_x = (int(coord[0]) + length, int(coord[1]) - length)
305
- rgb = draw_line(
306
- img,
307
- coord_y,
308
- coord_x,
309
- vector_colors[t, i].astype(int),
310
- self.linewidth,
311
- )
312
- res_video[t] = np.array(img)
313
-
314
- # construct the final rgb sequence
315
- if self.show_first_frame > 0:
316
- res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
317
- return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
318
-
319
- def _draw_pred_tracks(
320
  self,
321
- rgb: np.ndarray, # H x W x 3
322
- tracks: np.ndarray, # T x 2
323
- vector_colors: np.ndarray,
324
- alpha: float = 0.5,
325
- ):
326
- T, N, _ = tracks.shape
327
- rgb = Image.fromarray(np.uint8(rgb))
328
- for s in range(T - 1):
329
- vector_color = vector_colors[s]
330
- original = rgb.copy()
331
- alpha = (s / T) ** 2
332
- for i in range(N):
333
- coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
334
- coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
335
- if coord_y[0] != 0 and coord_y[1] != 0:
336
- rgb = draw_line(
337
- rgb,
338
- coord_y,
339
- coord_x,
340
- vector_color[i].astype(int),
341
- self.linewidth,
342
- )
343
- if self.tracks_leave_trace > 0:
344
- rgb = Image.fromarray(
345
- np.uint8(
346
- add_weighted(
347
- np.array(rgb), alpha, np.array(original), 1 - alpha, 0
 
 
 
 
 
 
 
 
 
 
 
 
348
  )
349
- )
350
- )
351
- rgb = np.array(rgb)
352
- return rgb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
- def _draw_gt_tracks(
 
 
 
 
 
 
 
 
 
 
 
355
  self,
356
- rgb: np.ndarray, # H x W x 3,
357
- gt_tracks: np.ndarray, # T x 2
358
- vector_colors: np.ndarray = None,
359
- ):
360
- T, N, _ = gt_tracks.shape
361
- if vector_colors is None:
362
- color = np.array((211, 0, 0))
363
- rgb = Image.fromarray(np.uint8(rgb))
364
- for t in range(T):
365
- if vector_colors is not None:
366
- vector_color = vector_colors[t]
367
- for i in range(N):
368
- if vector_colors is not None:
369
- color = vector_color[i].astype(int)
370
- gt_tracks = gt_tracks[t][i]
371
- # draw a red cross
372
- if gt_tracks[0] > 0 and gt_tracks[1] > 0:
373
- length = self.linewidth * 3
374
- coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
375
- coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
376
- rgb = draw_line(
377
- rgb,
378
- coord_y,
379
- coord_x,
380
- color,
381
- self.linewidth,
 
 
 
 
 
382
  )
383
- coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
384
- coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
385
- rgb = draw_line(
386
- rgb,
387
- coord_y,
388
- coord_x,
389
- color,
390
- self.linewidth,
 
 
 
391
  )
392
- rgb = np.array(rgb)
393
- return rgb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+
8
  import os
9
+ from typing import List
10
+
11
  import imageio
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
  import torch
15
+ import torch.nn.functional as F
16
 
17
  from matplotlib import cm
 
 
 
18
  from PIL import Image, ImageDraw
19
 
20
 
21
def draw_circle_on_image(
    image: Image,
    center: tuple,
    radius: int,
    color: tuple = (255, 0, 0),
    visible: bool = True,
    alpha: int = None,
) -> Image:
    """Paint a circular marker onto a PIL image and return that same image.

    The image is modified in place through an ``ImageDraw`` handle opened in
    ``"RGBA"`` mode, so the alpha component of the colour is honoured.

    Args:
        image: PIL image that receives the marker.
        center: Sequence whose first two items are the x and y of the centre.
        radius: Circle radius in pixels.
        color: RGB components of the marker colour.
        visible: When true the circle is filled and a half-transparent halo,
            half a pixel wider, is drawn underneath it; when false only the
            outline is drawn.
        alpha: Alpha component appended to ``color``; ``None`` means 255.

    Returns:
        The ``image`` object that was passed in.
    """
    canvas = ImageDraw.Draw(image, "RGBA")
    cx, cy = center[0], center[1]

    opacity = 255 if alpha is None else alpha
    rgba = tuple([*color, opacity])

    if visible:
        # Softer edge: same colour at half the alpha, 0.5 px beyond the radius.
        halo_extent = radius + 0.5
        halo_rgba = tuple([*rgba[:-1], int(rgba[-1] * 0.5)])
        canvas.ellipse(
            [
                (cx - halo_extent, cy - halo_extent),
                (cx + halo_extent, cy + halo_extent),
            ],
            fill=halo_rgba,
        )
        fill_rgba = rgba
    else:
        fill_rgba = None

    canvas.ellipse(
        [(cx - radius, cy - radius), (cx + radius, cy + radius)],
        fill=fill_rgba,
        outline=rgba,
    )
    return image
61
+
62
+
63
def draw_line_segment(
    image: Image, start: tuple, end: tuple, color: tuple, width: int
) -> Image:
    """Stroke a straight line between two points on a PIL image.

    The image is modified in place and also returned.

    Args:
        image: PIL image that receives the line.
        start: Sequence whose first two items are the x and y of the first endpoint.
        end: Sequence whose first two items are the x and y of the second endpoint.
        color: Colour components; converted to a tuple before being used as fill.
        width: Stroke width in pixels.

    Returns:
        The ``image`` object that was passed in.
    """
    canvas = ImageDraw.Draw(image)
    x0, y0 = start[0], start[1]
    x1, y1 = end[0], end[1]
    stroke = tuple(color)
    canvas.line((x0, y0, x1, y1), fill=stroke, width=width)
    return image
81
+
82
+
83
def blend_images(
    image1: np.ndarray, alpha: float, image2: np.ndarray, beta: float, gamma: float
) -> np.ndarray:
    """Compute the saturated weighted sum ``image1 * alpha + image2 * beta + gamma``.

    Args:
        image1: First image array.
        alpha: Weight applied to ``image1``.
        image2: Second image array, broadcast-compatible with ``image1``.
        beta: Weight applied to ``image2``.
        gamma: Scalar added to the weighted sum.

    Returns:
        The blended image as a ``uint8`` array. Values are clipped to
        ``[0, 255]`` before the cast, so out-of-range sums saturate instead of
        wrapping around modulo 256; fractional parts are truncated.
    """
    weighted = (
        np.asarray(image1, dtype=np.float64) * alpha
        + np.asarray(image2, dtype=np.float64) * beta
        + gamma
    )
    # A bare uint8 cast would wrap (e.g. 400 -> 144); saturate first.
    return np.clip(weighted, 0, 255).astype("uint8")
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
class Visualizer:
    """Draw tracked points and their trajectories on top of video frames.

    Args:
        output_dir: Directory that receives the rendered ``.mp4`` files.
        padding: Width in pixels of the white border added around every frame.
        fps: Frame rate of the saved video.
        colormap: ``"rainbow"`` (colour by y-coordinate, ``gist_rainbow``) or
            ``"spring"`` (colour by time / segmentation class).
        line_width: Width of trajectory lines in pixels; point markers use a
            radius of twice this value.
        initial_frame_repeat: Number of copies of the first frame placed at the
            start of the output (0 keeps the frames unchanged).
        track_history_length: Number of past frames whose trajectory is drawn
            (0 = no trajectory lines, -1 = the full history).

    Raises:
        ValueError: If ``colormap`` is neither ``"rainbow"`` nor ``"spring"``.
    """

    def __init__(
        self,
        output_dir: str = "./results",
        padding: int = 0,
        fps: int = 10,
        colormap: str = "rainbow",
        line_width: int = 2,
        initial_frame_repeat: int = 10,
        track_history_length: int = 0,
    ):
        if colormap not in ("rainbow", "spring"):
            raise ValueError("Colormap must be 'rainbow' or 'spring'")

        self.output_dir = output_dir
        self.padding = padding
        self.fps = fps
        self.line_width = line_width
        self.initial_frame_repeat = initial_frame_repeat
        self.track_history_length = track_history_length
        self.colormap = colormap

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
        # is available in both old and new releases.
        self.color_mapper = plt.get_cmap(
            "gist_rainbow" if colormap == "rainbow" else "spring"
        )

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B,T,N,1) or (B,T,N) bool
        segmentation: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        query_frame: int = 0,
        save_video: bool = True,
        point_opacity: float = 1.0,
    ) -> torch.Tensor:
        """Render the tracks onto the video and optionally save the result.

        Only the first batch element is rendered.

        Args:
            video: Input video tensor of shape (B,T,C,H,W) with C == 3.
            tracks: Track coordinates (x, y) of shape (B,T,N,2).
            visibility: Optional visibility mask of shape (B,T,N,1) or (B,T,N).
            segmentation: Optional segmentation mask of shape (B,1,H,W); it is
                sampled at the (unpadded) track positions of ``query_frame``.
            filename: Output file name without extension.
            query_frame: Frame index used for colour assignment.
            save_video: Whether to write the rendered video to ``output_dir``.
            point_opacity: Opacity of the point markers in [0, 1]; values
                outside the range are clamped.

        Returns:
            Rendered video as a uint8 tensor of shape (1,T',3,H',W').
        """
        if segmentation is not None:
            segmentation = self._sample_segmentation(
                tracks, segmentation, query_frame
            )

        # White border around every frame (pads the last two dimensions).
        padded_video = F.pad(
            video,
            (self.padding, self.padding, self.padding, self.padding),
            mode="constant",
            value=255,
        )

        opacity_value = min(max(int(point_opacity * 255), 0), 255)

        # Shift the coordinates so they still line up with the padded frames.
        padded_tracks = tracks + self.padding

        output_video = self.draw_tracks_on_video(
            video=padded_video,
            tracks=padded_tracks,
            visibility=visibility,
            segmentation=segmentation,
            query_frame=query_frame,
            opacity=opacity_value,
        )

        if save_video:
            self.save_video(output_video, filename=filename)

        return output_video

    @staticmethod
    def _sample_segmentation(
        tracks: torch.Tensor, segmentation: torch.Tensor, query_frame: int
    ) -> torch.Tensor:
        """Look up the segmentation label under every track at ``query_frame``.

        Args:
            tracks: Track coordinates (x, y) of shape (B,T,N,2).
            segmentation: Mask whose last two dimensions are (H, W); documented
                by the caller as (B,1,H,W).
            query_frame: Frame whose track positions are sampled.

        Returns:
            Integer tensor of shape (N,) with one label per track.
        """
        # A (B,1,H,W) mask has a single plane; indexing it with query_frame > 0
        # would raise, so fall back to the last available plane.
        mask_index = query_frame
        if isinstance(query_frame, int) and query_frame >= segmentation.shape[1]:
            mask_index = segmentation.shape[1] - 1
        mask = segmentation[0, mask_index]
        height, width = mask.shape[-2], mask.shape[-1]

        coords = tracks[0, query_frame].round().long().to(mask.device)
        # Clamp so off-frame tracks read the nearest border pixel instead of
        # wrapping around (negative index) or raising (index >= size).
        rows = coords[:, 1].clamp(0, height - 1)
        cols = coords[:, 0].clamp(0, width - 1)
        return mask[rows, cols].long()

    def save_video(self, video: torch.Tensor, filename: str):
        """Write the first batch element of ``video`` to ``<output_dir>/<filename>.mp4``.

        Failures while writing are reported on stdout and not re-raised.

        Args:
            video: Video tensor of shape (B,T,C,H,W).
            filename: Output file name without extension.
        """
        os.makedirs(self.output_dir, exist_ok=True)

        # (T) x (H,W,C) uint8 frames of the first batch element.
        frames = [
            frame[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            for frame in video.unbind(1)
        ]

        output_path = os.path.join(self.output_dir, f"{filename}.mp4")

        try:
            with imageio.get_writer(output_path, fps=self.fps, quality=8) as writer:
                # NOTE(review): the first two frames and the last frame are
                # skipped, as in the previous implementation; confirm this
                # trimming is still intended.
                for frame in frames[2:-1]:
                    writer.append_data(frame)

            print(f"Successfully saved video to {output_path}")

        except Exception as e:
            print(f"Error saving video to {output_path}: {str(e)}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segmentation: torch.Tensor = None,
        query_frame: int = 0,
        opacity: int = 255,
    ) -> torch.Tensor:
        """Draw trajectory lines and point markers on every frame.

        Args:
            video: Video tensor of shape (B,T,C,H,W) with C == 3.
            tracks: Track coordinates tensor of shape (B,T,N,2).
            visibility: Optional visibility mask of shape (B,T,N) or (B,T,N,1).
            segmentation: Optional per-track labels of shape (N,) used for
                colouring.
            query_frame: Frame index used for rainbow colouring and as the
                frame after which trajectory lines start.
            opacity: Alpha of the point markers (0-255).

        Returns:
            uint8 tensor of shape (1,T',3,H,W), where T' accounts for the
            repeated first frame.
        """
        _, num_frames, channels, _, _ = video.shape
        _, _, num_points, dims = tracks.shape
        assert dims == 2 and channels == 3, "Invalid input dimensions"

        # Frames become uint8 HWC arrays; track coordinates stay floating point.
        video_np = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()
        tracks_np = tracks[0].detach().cpu().numpy()

        output_frames = [frame.copy() for frame in video_np]

        if segmentation is not None:
            track_colors = self._assign_segmentation_colors(
                tracks_np, segmentation, num_frames, num_points
            )
        else:
            track_colors = self._assign_track_colors(
                tracks_np, query_frame, num_frames, num_points
            )

        if self.track_history_length != 0:
            output_frames = self._draw_track_lines(
                output_frames, tracks_np, track_colors, query_frame, num_frames
            )

        output_frames = self._draw_track_points(
            output_frames, tracks_np, track_colors, visibility, opacity
        )

        # Hold the first frame for a while so the initial points can be seen.
        if self.initial_frame_repeat > 0:
            output_frames = [
                output_frames[0]
            ] * self.initial_frame_repeat + output_frames[1:]

        return (
            torch.from_numpy(np.stack(output_frames)).permute(0, 3, 1, 2)[None].byte()
        )

    def _assign_track_colors(
        self, tracks: np.ndarray, query_frame: int, num_frames: int, num_points: int
    ) -> np.ndarray:
        """Colour tracks by y-coordinate (rainbow) or by time (any other colormap).

        Args:
            tracks: Track coordinates array of shape (num_frames, num_points, 2).
            query_frame: Frame index used for rainbow colouring; a tensor is
                treated as one frame index per point.
            num_frames: Total number of frames.
            num_points: Number of tracked points.

        Returns:
            uint8 array of track colours with shape (num_frames, num_points, 3).
        """
        track_colors = np.zeros((num_frames, num_points, 3))
        if num_points == 0:
            # Nothing to colour; min()/max() below would fail on empty input.
            return track_colors.astype(np.uint8)

        if self.colormap == "rainbow":
            per_point_query = isinstance(query_frame, torch.Tensor)
            if per_point_query:
                # Index NumPy arrays with NumPy indices (also works for CUDA).
                query_frame = query_frame.detach().cpu().numpy()

            y_coords = tracks[query_frame, :, 1]
            y_min, y_max = y_coords.min(), y_coords.max()
            if y_min == y_max:
                y_max = y_min + 1  # Avoid division by zero
            norm = plt.Normalize(y_min, y_max)

            for point_idx in range(num_points):
                query_idx = query_frame[point_idx] if per_point_query else query_frame
                rgba = self.color_mapper(norm(tracks[query_idx, point_idx, 1]))
                # One colour per track, constant over time.
                track_colors[:, point_idx] = np.array(rgba[:3]) * 255
        else:
            # One colour per frame, shared by all tracks.
            for frame_idx in range(num_frames):
                rgba = self.color_mapper(frame_idx / max(1, num_frames - 1))
                track_colors[frame_idx] = np.array(rgba[:3]) * 255

        return track_colors.astype(np.uint8)

    def _assign_segmentation_colors(
        self,
        tracks: np.ndarray,
        segmentation: torch.Tensor,
        num_frames: int,
        num_points: int,
    ) -> np.ndarray:
        """Colour tracks according to their segmentation label.

        In rainbow mode background tracks (label <= 0) are white and foreground
        tracks are coloured by their y-coordinate in the first frame. Otherwise
        foreground and background get the two end colours of the colormap.

        Args:
            tracks: Track coordinates array of shape (num_frames, num_points, 2).
            segmentation: Per-track labels of shape (num_points,).
            num_frames: Total number of frames.
            num_points: Number of tracked points.

        Returns:
            uint8 array of track colours with shape (num_frames, num_points, 3).
        """
        # Work on a NumPy copy so the masks can index NumPy arrays regardless
        # of the device the labels live on.
        if isinstance(segmentation, torch.Tensor):
            labels = segmentation.detach().cpu().numpy()
        else:
            labels = np.asarray(segmentation)
        foreground = labels > 0
        background = labels <= 0

        if self.colormap == "rainbow":
            track_colors = np.zeros((num_frames, num_points, 3))
            track_colors[:, background, :] = 255

            if foreground.any():
                y_coords = tracks[0, foreground, 1]
                y_min, y_max = y_coords.min(), y_coords.max()
                if y_min == y_max:
                    y_max = y_min + 1  # Avoid division by zero
                norm = plt.Normalize(y_min, y_max)

                for point_idx in np.flatnonzero(foreground):
                    rgba = self.color_mapper(norm(tracks[0, point_idx, 1]))
                    track_colors[:, point_idx] = np.array(rgba[:3]) * 255
        else:
            colors = np.zeros((num_points, 3), dtype=np.float32)
            colors[foreground] = np.array(self.color_mapper(1.0)[:3]) * 255.0
            colors[background] = np.array(self.color_mapper(0.0)[:3]) * 255.0
            track_colors = np.repeat(colors[None], num_frames, axis=0)

        return track_colors.astype(np.uint8)

    def _draw_track_lines(
        self,
        frames: List[np.ndarray],
        tracks: np.ndarray,
        track_colors: np.ndarray,
        query_frame: int,
        num_frames: int,
    ) -> List[np.ndarray]:
        """Draw the recent trajectory of every track on the frames after ``query_frame``.

        Args:
            frames: List of video frames to draw on (updated in place).
            tracks: Array of track coordinates (num_frames, num_points, 2).
            track_colors: Array of track colours (num_frames, num_points, 3).
            query_frame: Frame index after which trajectories are drawn.
            num_frames: Total number of frames.

        Returns:
            The same list, with trajectory lines drawn.
        """
        for frame_idx in range(query_frame + 1, num_frames):
            # A negative history length means "everything since frame 0".
            if self.track_history_length >= 0:
                start_idx = max(0, frame_idx - self.track_history_length)
            else:
                start_idx = 0

            frames[frame_idx] = self._draw_track_segments(
                frames[frame_idx],
                tracks[start_idx : frame_idx + 1],
                track_colors[start_idx : frame_idx + 1],
            )

        return frames

    def _draw_track_segments(
        self,
        frame: np.ndarray,
        tracks: np.ndarray,
        colors: np.ndarray,
    ) -> np.ndarray:
        """Draw the line segments joining consecutive track positions on one frame.

        Args:
            frame: Video frame to draw on.
            tracks: Array of track coordinates (num_steps, num_points, 2).
            colors: Array of track colours (num_steps, num_points, 3).

        Returns:
            New frame array with the segments drawn.
        """
        num_segments, num_points, _ = tracks.shape
        frame_img = Image.fromarray(np.uint8(frame))

        for segment_idx in range(num_segments - 1):
            before_segment = frame_img.copy()

            # Cubic fall-off: older segments are blended in more faintly.
            alpha = (segment_idx / num_segments) ** 3

            # A track sitting at (0, 0) is treated as "no data" and skipped.
            valid_points = ~np.isclose(tracks[segment_idx], 0).all(axis=1)

            for point_idx in np.flatnonzero(valid_points):
                start = (
                    float(tracks[segment_idx, point_idx, 0]),
                    float(tracks[segment_idx, point_idx, 1]),
                )
                end = (
                    float(tracks[segment_idx + 1, point_idx, 0]),
                    float(tracks[segment_idx + 1, point_idx, 1]),
                )
                color = tuple(int(c) for c in colors[segment_idx, point_idx])
                frame_img = draw_line_segment(
                    frame_img, start, end, color, self.line_width
                )

            if self.track_history_length > 0:
                frame_img = Image.fromarray(
                    blend_images(
                        np.array(frame_img),
                        alpha,
                        np.array(before_segment),
                        1 - alpha,
                        0,
                    )
                )

        return np.array(frame_img)

    @staticmethod
    def _frame_visibility(
        visibility: torch.Tensor, frame_idx: int, num_points: int
    ) -> np.ndarray:
        """Return a flat boolean visibility vector of length ``num_points`` for one frame.

        Accepts masks of shape (B,T,N) or (B,T,N,1); ``None`` means every point
        is visible.
        """
        if visibility is None:
            return np.ones(num_points, dtype=bool)
        # Flatten so a trailing singleton axis cannot broadcast (N,1) against
        # (N,) into an (N,N) mask.
        per_point = visibility[0, frame_idx].detach().cpu().numpy()
        return per_point.reshape(-1).astype(bool)

    def _draw_track_points(
        self,
        frames: List[np.ndarray],
        tracks: np.ndarray,
        track_colors: np.ndarray,
        visibility: torch.Tensor,
        opacity: int,
    ) -> List[np.ndarray]:
        """Draw a circle marker for every valid, visible track on each frame.

        Args:
            frames: List of video frames to draw on (updated in place).
            tracks: Array of track coordinates (num_frames, num_points, 2).
            track_colors: Array of track colours (num_frames, num_points, 3).
            visibility: Optional visibility mask of shape (B,T,N) or (B,T,N,1).
            opacity: Alpha of the markers (0-255).

        Returns:
            The same list, with point markers drawn.
        """
        num_points = tracks.shape[1]
        radius = int(self.line_width * 2)

        # A track sitting at (0, 0) is treated as "no data" and skipped.
        valid_points = ~np.isclose(tracks, 0).all(axis=2)

        for frame_idx, frame in enumerate(frames):
            frame_img = Image.fromarray(np.uint8(frame))
            frame_visibility = self._frame_visibility(
                visibility, frame_idx, num_points
            )
            points_to_draw = np.logical_and(valid_points[frame_idx], frame_visibility)

            for point_idx in np.flatnonzero(points_to_draw):
                # Coordinates stay floating point; colours become plain ints.
                center = (
                    float(tracks[frame_idx, point_idx, 0]),
                    float(tracks[frame_idx, point_idx, 1]),
                )
                color = tuple(int(c) for c in track_colors[frame_idx, point_idx])

                frame_img = draw_circle_on_image(
                    frame_img,
                    center=center,
                    radius=radius,
                    color=color,
                    visible=bool(frame_visibility[point_idx]),
                    alpha=opacity,
                )

            frames[frame_idx] = np.array(frame_img)

        return frames