nikkar committed on
Commit
d0b6f98
·
verified ·
1 Parent(s): 03dc409

Update visualizer.py

Browse files
Files changed (1) hide show
  1. visualizer.py +326 -491
visualizer.py CHANGED
@@ -4,555 +4,390 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
-
8
  import os
9
- from typing import List
10
-
11
- import imageio
12
- import matplotlib.pyplot as plt
13
  import numpy as np
 
14
  import torch
15
- import torch.nn.functional as F
16
 
17
  from matplotlib import cm
 
 
 
18
  from PIL import Image, ImageDraw
19
 
20
 
21
def draw_circle_on_image(
    image: Image,
    center: tuple,
    radius: int,
    color: tuple = (255, 0, 0),
    visible: bool = True,
    alpha: int = None,
) -> Image:
    """Paint a circle onto a PIL image, in place, and return the image.

    Args:
        image: PIL image to draw on (drawn with an "RGBA" draw context).
        center: (x, y) position of the circle centre.
        radius: Circle radius in pixels.
        color: RGB colour of the circle.
        visible: When true the circle is filled and a half-opacity halo,
            0.5 px larger, is painted underneath; otherwise only the
            outline is drawn.
        alpha: Alpha component appended to ``color``; 255 when ``None``.

    Returns:
        The same image object that was passed in.
    """
    canvas = ImageDraw.Draw(image, 'RGBA')  # RGBA context so alpha blends
    cx, cy = center[0], center[1]

    opacity = 255 if alpha is None else alpha
    rgba = tuple(list(color) + [opacity])

    if visible:
        # Soft halo: same colour at half the alpha, half a pixel wider.
        halo_fill = tuple(list(rgba[:-1]) + [int(rgba[-1] * 0.5)])
        halo_box = [
            (cx - radius - 0.5, cy - radius - 0.5),
            (cx + radius + 0.5, cy + radius + 0.5),
        ]
        canvas.ellipse(halo_box, fill=halo_fill)

    circle_box = [(cx - radius, cy - radius), (cx + radius, cy + radius)]
    fill = rgba if visible else None
    canvas.ellipse(circle_box, fill=fill, outline=rgba)
    return image
61
-
62
-
63
def draw_line_segment(
    image: Image, start: tuple, end: tuple, color: tuple, width: int
) -> Image:
    """Paint a straight line onto a PIL image, in place, and return it.

    Args:
        image: PIL image to draw on.
        start: (x, y) of the first end point.
        end: (x, y) of the second end point.
        color: Line colour; converted to a tuple before drawing.
        width: Line width in pixels.

    Returns:
        The same image object that was passed in.
    """
    x0, y0 = start[0], start[1]
    x1, y1 = end[0], end[1]
    pen = ImageDraw.Draw(image)
    pen.line((x0, y0, x1, y1), fill=tuple(color), width=width)
    return image
81
-
82
-
83
def blend_images(
    image1: np.ndarray, alpha: float, image2: np.ndarray, beta: float, gamma: float
) -> np.ndarray:
    """Return ``image1 * alpha + image2 * beta + gamma`` cast to uint8.

    Args:
        image1: First image array.
        alpha: Weight applied to ``image1``.
        image2: Second image array.
        beta: Weight applied to ``image2``.
        gamma: Constant added to the weighted sum.

    Returns:
        The weighted sum converted with ``astype("uint8")`` (the cast
        truncates; it does not round or saturate).
    """
    weighted_sum = image1 * alpha + image2 * beta + gamma
    return weighted_sum.astype("uint8")
99
 
100
 
101
class Visualizer:
    """Draw tracked points and their trajectories on video frames.

    Args:
        save_dir: Directory that ``save_video`` writes into.
        grayscale: Accepted for interface compatibility; not read by this class.
        pad_value: Number of pixels of padding added on every side of each frame.
        fps: Frame rate of the written video.
        colormap: ``"rainbow"`` (colour by y-coordinate) or ``"spring"``
            (colour by time, or by foreground/background with a segmentation).
        linewidth: Width of trajectory lines in pixels; points use twice this
            value as their radius.
        show_first_frame: Number of times the first frame is repeated.
        tracks_leave_trace: Number of past frames to draw trajectories for
            (0 = none, -1 = all).

    Raises:
        ValueError: If ``colormap`` is neither ``"rainbow"`` nor ``"spring"``.
    """

    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        colormap: str = "rainbow",
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,
    ):
        self.save_dir = save_dir
        self.padding = pad_value
        self.fps = fps
        self.line_width = linewidth
        self.initial_frame_repeat = show_first_frame
        self.track_history_length = tracks_leave_trace

        # Set up colormap for track visualization
        self.colormap = colormap
        if colormap not in ["rainbow", "spring"]:
            raise ValueError("Colormap must be 'rainbow' or 'spring'")

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
        # pyplot.get_cmap is the retained equivalent.
        self.color_mapper = plt.get_cmap(
            "gist_rainbow" if colormap == "rainbow" else "spring"
        )

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B,T,N,1) bool
        segmentation: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        query_frame: int = 0,
        save_video: bool = True,
        point_opacity: float = 1.0,
    ) -> torch.Tensor:
        """Visualize tracked points and their trajectories on video frames.

        Args:
            video: Input video tensor of shape (B,T,C,H,W)
            tracks: Point track coordinates of shape (B,T,N,2)
            visibility: Optional visibility mask of shape (B,T,N,1) or (B,T,N)
            segmentation: Optional segmentation mask of shape (B,1,H,W)
            filename: Output filename for saved video
            query_frame: Frame index to use for color assignment
            save_video: Whether to save visualization video
            point_opacity: Opacity value for track points (0-1)

        Returns:
            Tensor containing visualization frames
        """
        # Sample the segmentation at each track's (rounded) query position so
        # it becomes one label per point.
        if segmentation is not None:
            coords = tracks[0, query_frame].round().long()
            segmentation = segmentation[0, query_frame][
                coords[:, 1], coords[:, 0]
            ].long()

        # Add a white border of self.padding pixels around every frame
        padded_video = F.pad(
            video,
            (self.padding, self.padding, self.padding, self.padding),
            mode="constant",
            value=255,
        )

        # Convert opacity to an integer alpha clamped to [0, 255]
        opacity_value = min(max(int(point_opacity * 255), 0), 255)

        # Shift track coordinates into the padded frame
        padded_tracks = tracks + self.padding

        output_video = self.draw_tracks_on_video(
            video=padded_video,
            tracks=padded_tracks,
            visibility=visibility,
            segmentation=segmentation,
            query_frame=query_frame,
            opacity=opacity_value,
        )

        if save_video:
            self.save_video(output_video, filename=filename)

        return output_video

    def save_video(self, video: torch.Tensor, filename: str):
        """Save video tensor as MP4 file.

        Failures while writing are reported with ``print`` and not re-raised.

        Args:
            video: Video tensor of shape (B,T,C,H,W); only batch item 0 is written
            filename: Output filename without extension
        """
        os.makedirs(self.save_dir, exist_ok=True)

        # One HxWxC uint8 array per time step of the first batch item
        frames = [
            frame[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            for frame in video.unbind(1)
        ]

        output_path = os.path.join(self.save_dir, f"{filename}.mp4")

        try:
            with imageio.get_writer(output_path, fps=self.fps, quality=8) as writer:
                # The first two frames and the last frame are skipped
                for frame in frames[2:-1]:
                    writer.append_data(frame)

            print(f"Successfully saved video to {output_path}")

        except Exception as e:
            print(f"Error saving video to {output_path}: {str(e)}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segmentation: torch.Tensor = None,
        query_frame: int = 0,
        opacity: int = 255,
    ) -> torch.Tensor:
        """Draw tracks on video frames.

        Args:
            video: Video tensor of shape (B,T,C,H,W)
            tracks: Track coordinates tensor of shape (B,T,N,2)
            visibility: Optional visibility mask of shape (B,T,N) or (B,T,N,1)
            segmentation: Optional per-point segmentation labels for coloring
            query_frame: Frame index to use for rainbow coloring
            opacity: Opacity value for track points (0-255)

        Returns:
            Video tensor with visualized tracks of shape (1,T',3,H,W), where
            T' accounts for the repeated first frame
        """
        _, num_frames, channels, _, _ = video.shape
        _, _, num_points, dims = tracks.shape
        assert dims == 2 and channels == 3, "Invalid input dimensions"

        # Only batch item 0 is rendered
        video_np = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()
        tracks_np = tracks[0].detach().cpu().numpy()

        output_frames = [frame.copy() for frame in video_np]

        # Assign colors to tracks based on segmentation or position
        if segmentation is not None:
            track_colors = self._assign_segmentation_colors(
                tracks_np, segmentation, num_frames, num_points
            )
        else:
            track_colors = self._assign_track_colors(
                tracks_np, query_frame, num_frames, num_points
            )

        # Draw track history lines if enabled
        if self.track_history_length != 0:
            output_frames = self._draw_track_lines(
                output_frames, tracks_np, track_colors, query_frame, num_frames
            )

        output_frames = self._draw_track_points(
            output_frames, tracks_np, track_colors, visibility, opacity
        )

        # Hold the first frame for initial_frame_repeat frames
        if self.initial_frame_repeat > 0:
            output_frames = [
                output_frames[0]
            ] * self.initial_frame_repeat + output_frames[1:]

        return (
            torch.from_numpy(np.stack(output_frames)).permute(0, 3, 1, 2)[None].byte()
        )

    def _assign_track_colors(
        self, tracks: np.ndarray, query_frame: int, num_frames: int, num_points: int
    ) -> np.ndarray:
        """Assign colors by rainbow mapping of y-coordinates or by time.

        Args:
            tracks: Track coordinates array of shape (num_frames, num_points, 2)
            query_frame: Frame index to use for rainbow coloring
            num_frames: Total number of frames
            num_points: Number of tracked points

        Returns:
            uint8 array of track colors with shape (num_frames, num_points, 3)
        """
        track_colors = np.zeros((num_frames, num_points, 3))

        if self.colormap == "rainbow":
            # Normalize y-coordinates to [0,1] range for rainbow coloring
            y_coords = tracks[query_frame, :, 1]
            y_min, y_max = y_coords.min(), y_coords.max()
            if y_min == y_max:
                y_max = y_min + 1  # Avoid division by zero
            norm = plt.Normalize(y_min, y_max)

            for point_idx in range(num_points):
                # A tensor query_frame carries one query index per point
                query_idx = (
                    query_frame[point_idx]
                    if isinstance(query_frame, torch.Tensor)
                    else query_frame
                )
                color = (
                    np.array(
                        self.color_mapper(norm(tracks[query_idx, point_idx, 1]))[:3]
                    )[None]
                    * 255
                )
                track_colors[:, point_idx] = np.repeat(color, num_frames, axis=0)
        else:
            # One color per frame, varying smoothly from first to last frame
            for frame_idx in range(num_frames):
                color = (
                    np.array(self.color_mapper(frame_idx / max(1, num_frames - 1))[:3])[
                        None
                    ]
                    * 255
                )
                track_colors[frame_idx] = np.repeat(color, num_points, axis=0)

        return track_colors.astype(np.uint8)

    def _assign_segmentation_colors(
        self,
        tracks: np.ndarray,
        segmentation: torch.Tensor,
        num_frames: int,
        num_points: int,
    ) -> np.ndarray:
        """Assign colors to tracks based on segmentation labels and colormap.

        Args:
            tracks: Track coordinates array of shape (num_frames, num_points, 2)
            segmentation: Per-point labels of shape (num_points,); > 0 is foreground
            num_frames: Total number of frames
            num_points: Number of tracked points

        Returns:
            uint8 array of track colors with shape (num_frames, num_points, 3)
        """
        track_colors = np.zeros((num_frames, num_points, 3))

        if self.colormap == "rainbow":
            # Background points are white
            background_mask = segmentation <= 0
            track_colors[:, background_mask, :] = 255

            # Foreground points are colored by their y-coordinate in frame 0
            foreground_mask = segmentation > 0
            if torch.any(foreground_mask):
                y_coords = tracks[0, foreground_mask, 1]
                y_min, y_max = y_coords.min(), y_coords.max()
                if y_min == y_max:
                    y_max = y_min + 1  # Avoid division by zero
                norm = plt.Normalize(y_min, y_max)

                for point_idx in range(num_points):
                    if segmentation[point_idx] > 0:
                        color = (
                            np.array(
                                self.color_mapper(norm(tracks[0, point_idx, 1]))[:3]
                            )[None]
                            * 255
                        )
                        track_colors[:, point_idx] = np.repeat(
                            color, num_frames, axis=0
                        )
        else:
            # Two-color scheme: colormap(1.0) foreground, colormap(0.0) background
            segmentation = segmentation.cpu()
            colors = np.zeros((num_points, 3), dtype=np.float32)
            colors[segmentation > 0] = (
                np.array(self.color_mapper(1.0)[:3]) * 255.0
            )  # Foreground
            colors[segmentation <= 0] = (
                np.array(self.color_mapper(0.0)[:3]) * 255.0
            )  # Background
            track_colors = np.repeat(colors[None], num_frames, axis=0)

        return track_colors.astype(np.uint8)

    def _draw_track_lines(
        self,
        frames: List[np.ndarray],
        tracks: np.ndarray,
        track_colors: np.ndarray,
        query_frame: int,
        num_frames: int,
    ) -> List[np.ndarray]:
        """Draw track lines showing point trajectories over time.

        Args:
            frames: List of video frames to draw on (modified in place)
            tracks: Array of track coordinates (num_frames, num_points, 2)
            track_colors: Array of track colors (num_frames, num_points, 3)
            query_frame: Frame index where tracking starts
            num_frames: Total number of frames

        Returns:
            List of frames with track lines drawn
        """
        for frame_idx in range(query_frame + 1, num_frames):
            # A negative history length means "keep everything from frame 0"
            start_idx = (
                max(0, frame_idx - self.track_history_length)
                if self.track_history_length >= 0
                else 0
            )

            curr_tracks = tracks[start_idx : frame_idx + 1]
            curr_colors = track_colors[start_idx : frame_idx + 1]

            frames[frame_idx] = self._draw_track_segments(
                frames[frame_idx], curr_tracks, curr_colors
            )

        return frames

    def _draw_track_segments(
        self,
        frame: np.ndarray,
        tracks: np.ndarray,
        colors: np.ndarray,
    ) -> np.ndarray:
        """Draw the line segments joining consecutive track positions.

        Args:
            frame: Video frame to draw on
            tracks: Array of track coordinates (num_segments, num_points, 2)
            colors: Array of track colors (num_segments, num_points, 3)

        Returns:
            Frame with track segments drawn
        """
        num_segments, num_points, _ = tracks.shape
        frame_img = Image.fromarray(np.uint8(frame))

        for segment_idx in range(num_segments - 1):
            segment_color = colors[segment_idx]
            original = frame_img.copy()

            # Cubic falloff: older segments are blended in more faintly
            alpha = (segment_idx / num_segments) ** 3

            # Points sitting at (0, 0) are treated as missing
            valid_points = ~np.isclose(tracks[segment_idx], 0).all(axis=1)

            for point_idx in range(num_points):
                if valid_points[point_idx]:
                    start = (
                        tracks[segment_idx, point_idx, 0],
                        tracks[segment_idx, point_idx, 1],
                    )
                    end = (
                        tracks[segment_idx + 1, point_idx, 0],
                        tracks[segment_idx + 1, point_idx, 1],
                    )

                    frame_img = draw_line_segment(
                        frame_img,
                        start,
                        end,
                        segment_color[point_idx].astype(int),
                        self.line_width,
                    )

            if self.track_history_length > 0:
                frame_img = Image.fromarray(
                    blend_images(
                        np.array(frame_img), alpha, np.array(original), 1 - alpha, 0
                    )
                )

        return np.array(frame_img)

    def _draw_track_points(
        self,
        frames: List[np.ndarray],
        tracks: np.ndarray,
        track_colors: np.ndarray,
        visibility: torch.Tensor,
        opacity: int,
    ) -> List[np.ndarray]:
        """Draw tracked points on each frame with circles.

        Args:
            frames: List of video frames to draw on (modified in place)
            tracks: Array of track coordinates (num_frames, num_points, 2)
            track_colors: Array of track colors (num_frames, num_points, 3)
            visibility: Optional tensor of shape (B,T,N) or (B,T,N,1)
            opacity: Opacity value for drawing points

        Returns:
            List of frames with track points drawn
        """
        num_points = tracks.shape[1]
        frame_imgs = [Image.fromarray(np.uint8(frame)) for frame in frames]

        # Points sitting at (0, 0) are treated as missing
        valid_points = ~np.isclose(tracks, 0).all(axis=2)

        for frame_idx, frame_img in enumerate(frame_imgs):
            if visibility is None:
                frame_visibility = np.ones(num_points, dtype=bool)
            else:
                # Flatten to (N,): a trailing singleton axis would otherwise
                # broadcast against valid_points into an N x N mask and make
                # every visible point be drawn once per valid point.
                frame_visibility = (
                    visibility[0, frame_idx].cpu().numpy().reshape(num_points)
                )

            points_to_draw = np.logical_and(valid_points[frame_idx], frame_visibility)

            for point_idx in np.where(points_to_draw)[0]:
                # Keep coordinates as floats
                coord = tuple(tracks[frame_idx, point_idx])
                color = track_colors[frame_idx, point_idx].astype(int)

                frame_img = draw_circle_on_image(
                    frame_img,
                    center=coord,
                    radius=int(self.line_width * 2),
                    color=color,
                    visible=frame_visibility[point_idx],
                    alpha=opacity,
                )

            frames[frame_idx] = np.array(frame_img)

        return frames
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
7
  import os
 
 
 
 
8
  import numpy as np
9
+ import imageio
10
  import torch
 
11
 
12
  from matplotlib import cm
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms as transforms
15
+ import matplotlib.pyplot as plt
16
  from PIL import Image, ImageDraw
17
 
18
 
19
def read_video_from_path(path):
    """Read every frame of a video file into a single array.

    Args:
        path: Location of the video, passed to ``imageio.get_reader``.

    Returns:
        The frames stacked along a new first axis with ``np.stack``, or
        ``None`` when the reader cannot be opened (the error is printed).
        ``np.stack`` raises ``ValueError`` if the reader yields no frames.
    """
    try:
        reader = imageio.get_reader(path)
    except Exception as e:
        print("Error opening video file: ", e)
        return None

    # Release the underlying file/decoder even if decoding a frame fails.
    try:
        frames = [np.array(im) for im in reader]
    finally:
        reader.close()
    return np.stack(frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True, color_alpha=None):
    """Draw a circle on a PIL image, in place, and return the image.

    Args:
        rgb: PIL image to draw on.
        coord: (x, y) position of the circle centre.
        radius: Circle radius in pixels.
        color: RGB colour of the circle.
        visible: When true the circle is filled; otherwise only outlined.
        color_alpha: Alpha component appended to ``color``; 255 when ``None``.

    Returns:
        The same image object that was passed in.
    """
    # Pillow only blends the alpha component when the draw context uses the
    # "RGBA" colour mode, which it permits on "RGB" images. Without it
    # color_alpha is ignored. Other image modes keep the default context.
    if rgb.mode == "RGB":
        draw = ImageDraw.Draw(rgb, "RGBA")
    else:
        draw = ImageDraw.Draw(rgb)

    # Bounding box of the circle
    left_up_point = (coord[0] - radius, coord[1] - radius)
    right_down_point = (coord[0] + radius, coord[1] + radius)

    color = tuple(list(color) + [color_alpha if color_alpha is not None else 255])

    draw.ellipse(
        [left_up_point, right_down_point],
        fill=color if visible else None,
        outline=color,
    )
    return rgb
46
 
 
47
 
48
def draw_line(rgb, coord_y, coord_x, color, linewidth):
    """Draw a straight line on a PIL image, in place, and return the image.

    Despite their names, ``coord_y`` and ``coord_x`` are the two end points
    of the line; element 0 and element 1 of each are passed to
    ``ImageDraw.line`` as that point's coordinates.

    Args:
        rgb: PIL image to draw on.
        coord_y: First end point.
        coord_x: Second end point.
        color: Line colour; converted to a tuple before drawing.
        linewidth: Line width in pixels.

    Returns:
        The same image object that was passed in.
    """
    pen = ImageDraw.Draw(rgb)
    x0, y0 = coord_y[0], coord_y[1]
    x1, y1 = coord_x[0], coord_x[1]
    pen.line((x0, y0, x1, y1), fill=tuple(color), width=linewidth)
    return rgb
 
56
 
57
+
58
def add_weighted(rgb, alpha, original, beta, gamma):
    """Return ``rgb * alpha + original * beta + gamma`` cast to uint8.

    The cast uses ``astype("uint8")``, so values are truncated rather than
    rounded or saturated.
    """
    weighted_sum = rgb * alpha + original * beta + gamma
    return weighted_sum.astype("uint8")
60
+
61
+
62
class Visualizer:
    """Draw point tracks on a video and optionally save the result.

    Args:
        save_dir: Directory that ``save_video`` writes MP4 files into.
        grayscale: Convert the video to 3-channel grayscale before drawing.
        pad_value: Number of white pixels added on every side of each frame.
        fps: Frame rate used when saving or logging the video.
        mode: Colouring scheme: ``"rainbow"``, ``"cool"`` or ``"optical_flow"``.
        linewidth: Line width in pixels for trajectories and cross markers.
        show_first_frame: Number of times the first frame is repeated.
        tracks_leave_trace: Number of past frames to draw trajectories for
            (0 = none, -1 = all).

    Raises:
        ValueError: If ``mode`` is not one of the supported schemes.
    """

    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
        # pyplot.get_cmap is the retained equivalent.
        if mode == "rainbow":
            self.color_map = plt.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = plt.get_cmap(mode)
        elif mode == "optical_flow":
            # Colours come from flow_vis in this mode; no colormap is used.
            self.color_map = None
        else:
            # Fail here rather than with an AttributeError at draw time.
            raise ValueError(
                f"Unsupported mode {mode!r}; "
                "expected 'rainbow', 'cool' or 'optical_flow'"
            )
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame=0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
        opacity: float = 1.0,
    ):
        """Pad the video, draw the tracks on it and optionally save it.

        Args:
            video: Video tensor of shape (B,T,C,H,W).
            tracks: Predicted track coordinates of shape (B,T,N,2).
            visibility: Optional per-point visibility.
            gt_tracks: Optional ground-truth tracks drawn as red crosses.
            segm_mask: Optional segmentation mask; required when
                ``compensate_for_camera_motion`` is set.
            filename: Output name (file stem, or tag when ``writer`` is given).
            writer: Optional object with an ``add_video`` method.
            step: Global step passed to ``writer.add_video``.
            query_frame: Frame index used for colour assignment.
            save_video: Whether to call ``save_video`` on the result.
            compensate_for_camera_motion: Subtract mean background motion
                from the drawn trajectories.
            opacity: Point opacity in [0, 1]; values outside are clamped.

        Returns:
            The uint8 tensor produced by ``draw_tracks_on_video``.
        """
        if compensate_for_camera_motion:
            assert segm_mask is not None
        # NOTE(review): the conversion of segm_mask from a (B,1,H,W) image to
        # one label per track was disabled in this revision, so segm_mask is
        # forwarded unchanged. draw_tracks_on_video indexes it per point
        # (segm_mask[n], vector_colors[:, segm_mask <= 0]); presumably callers
        # must now pass a per-point mask of length N. Confirm with callers.

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        # Clamp so out-of-range opacity cannot produce an invalid alpha.
        color_alpha = min(max(int(opacity * 255), 0), 255)
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            # Back to three channels so drawing code can assume C == 3.
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
            color_alpha=color_alpha,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        """Log the video to ``writer`` or write it to ``<save_dir>/<filename>.mp4``.

        Args:
            video: Video tensor of shape (B,T,C,H,W).
            filename: Tag for ``writer.add_video`` or file stem on disk.
            writer: Optional object with an ``add_video`` method; when given,
                nothing is written to disk.
            step: Global step passed to ``writer.add_video``.
        """
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
            return

        os.makedirs(self.save_dir, exist_ok=True)
        # One HxWxC array per time step of the first batch item
        frames = [
            frame[0].permute(1, 2, 0).cpu().numpy() for frame in video.unbind(1)
        ]

        save_path = os.path.join(self.save_dir, f"{filename}.mp4")
        video_writer = imageio.get_writer(save_path, fps=self.fps)
        try:
            # The first two frames and the last frame are skipped
            for frame in frames[2:-1]:
                video_writer.append_data(frame)
        finally:
            # Always release the file handle/encoder, even on a failed write.
            video_writer.close()

        print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame=0,
        compensate_for_camera_motion=False,
        color_alpha: int = 255,
    ):
        """Render trajectories and cross markers for batch item 0.

        Args:
            video: Video tensor of shape (B,T,3,H,W).
            tracks: Track coordinates of shape (B,T,N,2).
            visibility: Optional visibility indexed as ``visibility[0, t, i]``.
            segm_mask: Optional per-point mask; > 0 marks foreground.
            gt_tracks: Optional ground-truth tracks of shape (B,T,N,2).
            query_frame: Frame index (or per-point tensor) used for colouring.
            compensate_for_camera_motion: Subtract mean background motion and
                draw foreground points only.
            color_alpha: Accepted for interface compatibility; the cross
                markers drawn by this revision do not use it.

        Returns:
            uint8 tensor of shape (1,T',3,H,W); T' includes the repeated
            first frame.
        """
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # T, H, W, C
        tracks = tracks[0].detach().cpu().numpy()  # T, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = [rgb.copy() for rgb in video]
        vector_colors = np.zeros((T, N, 3))

        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                # One fixed colour per point, from its y-coordinate at the query frame
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if isinstance(query_frame, torch.Tensor):
                        query_frame_ = query_frame[n]
                    else:
                        query_frame_ = query_frame
                    color = self.color_map(norm(tracks[query_frame_, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                # Background points are white; foreground by y in frame 0
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                # A negative trace length means "keep everything from frame 0"
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    # Mean background displacement relative to frame t
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        # draw points (as crosses)
        for t in range(T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
                # A zero coordinate on either axis marks a point to skip
                if coord[0] == 0 or coord[1] == 0:
                    continue
                # With camera-motion compensation only foreground is drawn
                if compensate_for_camera_motion and not (segm_mask[i] > 0):
                    continue
                if visible:
                    self._draw_cross(
                        img, coord[0], coord[1], vector_colors[t, i].astype(int)
                    )
            res_video[t] = np.array(img)

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_cross(self, img, x, y, color):
        """Draw an X-shaped marker centred on (x, y) on a PIL image, in place."""
        length = self.linewidth * 3
        draw_line(
            img,
            (x + length, y + length),
            (x - length, y - length),
            color,
            self.linewidth,
        )
        draw_line(
            img,
            (x - length, y + length),
            (x + length, y - length),
            color,
            self.linewidth,
        )
        return img

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x N x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        """Draw the segments joining consecutive track positions on one frame.

        Args:
            rgb: Frame to draw on, H x W x 3.
            tracks: Track window of shape (T, N, 2).
            vector_colors: Colours of shape (T, N, 3).
            alpha: Ignored; the blend weight is recomputed for every segment.

        Returns:
            The frame as a numpy array with the segments drawn.
        """
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            # Quadratic falloff: older segments are blended in more faintly
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (tracks[s, i, 0], tracks[s, i, 1])
                coord_x = (tracks[s + 1, i, 0], tracks[s + 1, i, 1])
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                rgb = Image.fromarray(
                    np.uint8(
                        add_weighted(
                            np.array(rgb), alpha, np.array(original), 1 - alpha, 0
                        )
                    )
                )
        rgb = np.array(rgb)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x N x 2
        vector_colors: np.ndarray = None,
    ):
        """Draw a cross at every ground-truth position with positive coordinates.

        Args:
            rgb: Frame to draw on, H x W x 3.
            gt_tracks: Ground-truth window of shape (T, N, 2).
            vector_colors: Optional colours of shape (T, N, 3); a fixed red
                (211, 0, 0) is used when omitted.

        Returns:
            The frame as a numpy array with the crosses drawn.
        """
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                if vector_colors is not None:
                    color = vector_colors[t][i].astype(int)
                # Use a separate name: rebinding gt_tracks here would replace
                # the whole array with one point and break the next iteration.
                point = gt_tracks[t][i]
                if point[0] > 0 and point[1] > 0:
                    self._draw_cross(rgb, point[0], point[1], color)
        rgb = np.array(rgb)
        return rgb