nikkar commited on
Commit
a4369ad
·
verified ·
1 Parent(s): e17af04

Create visualizer.py

Browse files
Files changed (1) hide show
  1. visualizer.py +363 -0
visualizer.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import numpy as np
9
+ import imageio
10
+ import torch
11
+
12
+ from matplotlib import cm
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms as transforms
15
+ import matplotlib.pyplot as plt
16
+ from PIL import Image, ImageDraw
17
+
18
+
19
+ def read_video_from_path(path):
20
+ try:
21
+ reader = imageio.get_reader(path)
22
+ except Exception as e:
23
+ print("Error opening video file: ", e)
24
+ return None
25
+ frames = []
26
+ for i, im in enumerate(reader):
27
+ frames.append(np.array(im))
28
+ return np.stack(frames)
29
+
30
+
31
+ def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True, color_alpha=None):
32
+ # Create a draw object
33
+ draw = ImageDraw.Draw(rgb)
34
+ # Calculate the bounding box of the circle
35
+ left_up_point = (coord[0] - radius, coord[1] - radius)
36
+ right_down_point = (coord[0] + radius, coord[1] + radius)
37
+ # Draw the circle
38
+ color = tuple(list(color) + [color_alpha if color_alpha is not None else 255])
39
+
40
+ draw.ellipse(
41
+ [left_up_point, right_down_point],
42
+ fill=tuple(color) if visible else None,
43
+ outline=tuple(color),
44
+ )
45
+ return rgb
46
+
47
+
48
+ def draw_line(rgb, coord_y, coord_x, color, linewidth):
49
+ draw = ImageDraw.Draw(rgb)
50
+ draw.line(
51
+ (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
52
+ fill=tuple(color),
53
+ width=linewidth,
54
+ )
55
+ return rgb
56
+
57
+
58
+ def add_weighted(rgb, alpha, original, beta, gamma):
59
+ return (rgb * alpha + original * beta + gamma).astype("uint8")
60
+
61
+
62
+ class Visualizer:
63
+ def __init__(
64
+ self,
65
+ save_dir: str = "./results",
66
+ grayscale: bool = False,
67
+ pad_value: int = 0,
68
+ fps: int = 10,
69
+ mode: str = "rainbow", # 'cool', 'optical_flow'
70
+ linewidth: int = 2,
71
+ show_first_frame: int = 10,
72
+ tracks_leave_trace: int = 0, # -1 for infinite
73
+ ):
74
+ self.mode = mode
75
+ self.save_dir = save_dir
76
+ if mode == "rainbow":
77
+ self.color_map = cm.get_cmap("gist_rainbow")
78
+ elif mode == "cool":
79
+ self.color_map = cm.get_cmap(mode)
80
+ self.show_first_frame = show_first_frame
81
+ self.grayscale = grayscale
82
+ self.tracks_leave_trace = tracks_leave_trace
83
+ self.pad_value = pad_value
84
+ self.linewidth = linewidth
85
+ self.fps = fps
86
+
87
+ def visualize(
88
+ self,
89
+ video: torch.Tensor, # (B,T,C,H,W)
90
+ tracks: torch.Tensor, # (B,T,N,2)
91
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
92
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
93
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
94
+ filename: str = "video",
95
+ writer=None, # tensorboard Summary Writer, used for visualization during training
96
+ step: int = 0,
97
+ query_frame=0,
98
+ save_video: bool = True,
99
+ compensate_for_camera_motion: bool = False,
100
+ opacity: float = 1.0,
101
+ ):
102
+ if compensate_for_camera_motion:
103
+ assert segm_mask is not None
104
+ if segm_mask is not None:
105
+ coords = tracks[0, query_frame].round().long()
106
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
107
+
108
+ video = F.pad(
109
+ video,
110
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
111
+ "constant",
112
+ 255,
113
+ )
114
+ color_alpha = int(opacity * 255)
115
+ tracks = tracks + self.pad_value
116
+
117
+ if self.grayscale:
118
+ transform = transforms.Grayscale()
119
+ video = transform(video)
120
+ video = video.repeat(1, 1, 3, 1, 1)
121
+
122
+ res_video = self.draw_tracks_on_video(
123
+ video=video,
124
+ tracks=tracks,
125
+ visibility=visibility,
126
+ segm_mask=segm_mask,
127
+ gt_tracks=gt_tracks,
128
+ query_frame=query_frame,
129
+ compensate_for_camera_motion=compensate_for_camera_motion,
130
+ color_alpha=color_alpha,
131
+ )
132
+ if save_video:
133
+ self.save_video(res_video, filename=filename, writer=writer, step=step)
134
+ return res_video
135
+
136
+ def save_video(self, video, filename, writer=None, step=0):
137
+ if writer is not None:
138
+ writer.add_video(
139
+ filename,
140
+ video.to(torch.uint8),
141
+ global_step=step,
142
+ fps=self.fps,
143
+ )
144
+ else:
145
+ os.makedirs(self.save_dir, exist_ok=True)
146
+ wide_list = list(video.unbind(1))
147
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
148
+
149
+ # Prepare the video file path
150
+ save_path = os.path.join(self.save_dir, f"{filename}.mp4")
151
+
152
+ # Create a writer object
153
+ video_writer = imageio.get_writer(save_path, fps=self.fps)
154
+
155
+ # Write frames to the video file
156
+ for frame in wide_list[2:-1]:
157
+ video_writer.append_data(frame)
158
+
159
+ video_writer.close()
160
+
161
+ print(f"Video saved to {save_path}")
162
+
163
+ def draw_tracks_on_video(
164
+ self,
165
+ video: torch.Tensor,
166
+ tracks: torch.Tensor,
167
+ visibility: torch.Tensor = None,
168
+ segm_mask: torch.Tensor = None,
169
+ gt_tracks=None,
170
+ query_frame=0,
171
+ compensate_for_camera_motion=False,
172
+ color_alpha: int = 255,
173
+ ):
174
+ B, T, C, H, W = video.shape
175
+ _, _, N, D = tracks.shape
176
+
177
+ assert D == 2
178
+ assert C == 3
179
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
180
+ tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
181
+ if gt_tracks is not None:
182
+ gt_tracks = gt_tracks[0].detach().cpu().numpy()
183
+
184
+ res_video = []
185
+
186
+ # process input video
187
+ for rgb in video:
188
+ res_video.append(rgb.copy())
189
+ vector_colors = np.zeros((T, N, 3))
190
+
191
+ if self.mode == "optical_flow":
192
+ import flow_vis
193
+
194
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
195
+ elif segm_mask is None:
196
+ if self.mode == "rainbow":
197
+ y_min, y_max = (
198
+ tracks[query_frame, :, 1].min(),
199
+ tracks[query_frame, :, 1].max(),
200
+ )
201
+ norm = plt.Normalize(y_min, y_max)
202
+ for n in range(N):
203
+ if isinstance(query_frame, torch.Tensor):
204
+ query_frame_ = query_frame[n]
205
+ else:
206
+ query_frame_ = query_frame
207
+ color = self.color_map(norm(tracks[query_frame_, n, 1]))
208
+ color = np.array(color[:3])[None] * 255
209
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
210
+ else:
211
+ # color changes with time
212
+ for t in range(T):
213
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
214
+ vector_colors[t] = np.repeat(color, N, axis=0)
215
+ else:
216
+ if self.mode == "rainbow":
217
+ vector_colors[:, segm_mask <= 0, :] = 255
218
+
219
+ y_min, y_max = (
220
+ tracks[0, segm_mask > 0, 1].min(),
221
+ tracks[0, segm_mask > 0, 1].max(),
222
+ )
223
+ norm = plt.Normalize(y_min, y_max)
224
+ for n in range(N):
225
+ if segm_mask[n] > 0:
226
+ color = self.color_map(norm(tracks[0, n, 1]))
227
+ color = np.array(color[:3])[None] * 255
228
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
229
+
230
+ else:
231
+ # color changes with segm class
232
+ segm_mask = segm_mask.cpu()
233
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
234
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
235
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
236
+ vector_colors = np.repeat(color[None], T, axis=0)
237
+
238
+ # draw tracks
239
+ if self.tracks_leave_trace != 0:
240
+ for t in range(query_frame + 1, T):
241
+ first_ind = (
242
+ max(0, t - self.tracks_leave_trace)
243
+ if self.tracks_leave_trace >= 0
244
+ else 0
245
+ )
246
+ curr_tracks = tracks[first_ind : t + 1]
247
+ curr_colors = vector_colors[first_ind : t + 1]
248
+ if compensate_for_camera_motion:
249
+ diff = (
250
+ tracks[first_ind : t + 1, segm_mask <= 0]
251
+ - tracks[t : t + 1, segm_mask <= 0]
252
+ ).mean(1)[:, None]
253
+
254
+ curr_tracks = curr_tracks - diff
255
+ curr_tracks = curr_tracks[:, segm_mask > 0]
256
+ curr_colors = curr_colors[:, segm_mask > 0]
257
+
258
+ res_video[t] = self._draw_pred_tracks(
259
+ res_video[t],
260
+ curr_tracks,
261
+ curr_colors,
262
+ )
263
+ if gt_tracks is not None:
264
+ res_video[t] = self._draw_gt_tracks(
265
+ res_video[t], gt_tracks[first_ind : t + 1]
266
+ )
267
+
268
+ # draw points
269
+ for t in range(T):
270
+ img = Image.fromarray(np.uint8(res_video[t]))
271
+ for i in range(N):
272
+ coord = (tracks[t, i, 0], tracks[t, i, 1])
273
+ visibile = True
274
+ if visibility is not None:
275
+ visibile = visibility[0, t, i]
276
+ if coord[0] != 0 and coord[1] != 0:
277
+ if not compensate_for_camera_motion or (
278
+ compensate_for_camera_motion and segm_mask[i] > 0
279
+ ):
280
+ img = draw_circle(
281
+ img,
282
+ coord=coord,
283
+ radius=int(self.linewidth * 2),
284
+ color=vector_colors[t, i].astype(int),
285
+ visible=visibile,
286
+ color_alpha=color_alpha,
287
+ )
288
+ res_video[t] = np.array(img)
289
+
290
+ # construct the final rgb sequence
291
+ if self.show_first_frame > 0:
292
+ res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
293
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
294
+
295
+ def _draw_pred_tracks(
296
+ self,
297
+ rgb: np.ndarray, # H x W x 3
298
+ tracks: np.ndarray, # T x 2
299
+ vector_colors: np.ndarray,
300
+ alpha: float = 0.5,
301
+ ):
302
+ T, N, _ = tracks.shape
303
+ rgb = Image.fromarray(np.uint8(rgb))
304
+ for s in range(T - 1):
305
+ vector_color = vector_colors[s]
306
+ original = rgb.copy()
307
+ alpha = (s / T) ** 2
308
+ for i in range(N):
309
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
310
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
311
+ if coord_y[0] != 0 and coord_y[1] != 0:
312
+ rgb = draw_line(
313
+ rgb,
314
+ coord_y,
315
+ coord_x,
316
+ vector_color[i].astype(int),
317
+ self.linewidth,
318
+ )
319
+ if self.tracks_leave_trace > 0:
320
+ rgb = Image.fromarray(
321
+ np.uint8(
322
+ add_weighted(
323
+ np.array(rgb), alpha, np.array(original), 1 - alpha, 0
324
+ )
325
+ )
326
+ )
327
+ rgb = np.array(rgb)
328
+ return rgb
329
+
330
+ def _draw_gt_tracks(
331
+ self,
332
+ rgb: np.ndarray, # H x W x 3,
333
+ gt_tracks: np.ndarray, # T x 2
334
+ ):
335
+ T, N, _ = gt_tracks.shape
336
+ color = np.array((211, 0, 0))
337
+ rgb = Image.fromarray(np.uint8(rgb))
338
+ for t in range(T):
339
+ for i in range(N):
340
+ gt_tracks = gt_tracks[t][i]
341
+ # draw a red cross
342
+ if gt_tracks[0] > 0 and gt_tracks[1] > 0:
343
+ length = self.linewidth * 3
344
+ coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
345
+ coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
346
+ rgb = draw_line(
347
+ rgb,
348
+ coord_y,
349
+ coord_x,
350
+ color,
351
+ self.linewidth,
352
+ )
353
+ coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
354
+ coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
355
+ rgb = draw_line(
356
+ rgb,
357
+ coord_y,
358
+ coord_x,
359
+ color,
360
+ self.linewidth,
361
+ )
362
+ rgb = np.array(rgb)
363
+ return rgb