Seokju Cho commited on
Commit
f1586f7
·
1 Parent(s): 058b9ed

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
app.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import uuid
4
+
5
+ import gradio as gr
6
+ import mediapy
7
+ import numpy as np
8
+ import cv2
9
+ import matplotlib
10
+ import torch
11
+
12
+ from locotrack_pytorch.models.locotrack_model import load_model
13
+ from viz_utils import paint_point_track
14
+
15
+
16
+ PREVIEW_WIDTH = 768 # Width of the preview video
17
+ VIDEO_INPUT_RESO = (256, 256) # Resolution of the input video
18
+ POINT_SIZE = 4 # Size of the query point in the preview video
19
+ FRAME_LIMIT = 300 # Limit the number of frames to process
20
+
21
+
22
+ def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
23
+ print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")
24
+
25
+ current_frame = video_queried_preview[int(frame_num)]
26
+
27
+ # Get the mouse click
28
+ query_points[int(frame_num)].append((evt.index[0], evt.index[1], frame_num))
29
+
30
+ # Choose the color for the point from matplotlib colormap
31
+ color = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
32
+ color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
33
+ print(f"Color: {color}")
34
+ query_points_color[int(frame_num)].append(color)
35
+
36
+ # Draw the point on the frame
37
+ x, y = evt.index
38
+ current_frame_draw = cv2.circle(current_frame, (x, y), POINT_SIZE, color, -1)
39
+
40
+ # Update the frame
41
+ video_queried_preview[int(frame_num)] = current_frame_draw
42
+
43
+ # Update the query count
44
+ query_count += 1
45
+ return (
46
+ current_frame_draw, # Updated frame for preview
47
+ video_queried_preview, # Updated preview video
48
+ query_points, # Updated query points
49
+ query_points_color, # Updated query points color
50
+ query_count # Updated query count
51
+ )
52
+
53
+
54
+ def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
55
+ if len(query_points[int(frame_num)]) == 0:
56
+ return (
57
+ video_queried_preview[int(frame_num)],
58
+ video_queried_preview,
59
+ query_points,
60
+ query_points_color,
61
+ query_count
62
+ )
63
+
64
+ # Get the last point
65
+ query_points[int(frame_num)].pop(-1)
66
+ query_points_color[int(frame_num)].pop(-1)
67
+
68
+ # Redraw the frame
69
+ current_frame_draw = video_preview[int(frame_num)].copy()
70
+ for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
71
+ x, y, _ = point
72
+ current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)
73
+
74
+ # Update the query count
75
+ query_count -= 1
76
+
77
+ # Update the frame
78
+ video_queried_preview[int(frame_num)] = current_frame_draw
79
+ return (
80
+ current_frame_draw, # Updated frame for preview
81
+ video_queried_preview, # Updated preview video
82
+ query_points, # Updated query points
83
+ query_points_color, # Updated query points color
84
+ query_count # Updated query count
85
+ )
86
+
87
+
88
+ def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
89
+ query_count -= len(query_points[int(frame_num)])
90
+
91
+ query_points[int(frame_num)] = []
92
+ query_points_color[int(frame_num)] = []
93
+
94
+ video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()
95
+
96
+ return (
97
+ video_preview[int(frame_num)], # Set the preview frame to the original frame
98
+ video_queried_preview,
99
+ query_points, # Cleared query points
100
+ query_points_color, # Cleared query points color
101
+ query_count # New query count
102
+ )
103
+
104
+
105
+
106
+ def clear_all_fn(frame_num, video_preview):
107
+ return (
108
+ video_preview[int(frame_num)],
109
+ video_preview.copy(),
110
+ [[] for _ in range(len(video_preview))],
111
+ [[] for _ in range(len(video_preview))],
112
+ 0
113
+ )
114
+
115
+
116
+ def choose_frame(frame_num, video_preview_array):
117
+ return video_preview_array[int(frame_num)]
118
+
119
+
120
+ def extract_feature(video_input, model_size="small"):
121
+ device = "cuda" if torch.cuda.is_available() else "cpu"
122
+ dtype = torch.bfloat16 if device == "cuda" else torch.float16
123
+
124
+ model = load_model(model_size=model_size).to(device)
125
+
126
+ video_input = (video_input / 255.0) * 2 - 1
127
+ video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
128
+
129
+ with torch.autocast(device_type=device, dtype=dtype):
130
+ with torch.no_grad():
131
+ feature = model.get_feature_grids(video_input)
132
+
133
+ return feature
134
+
135
+
136
+ def preprocess_video_input(video_path, model_size):
137
+ video_arr = mediapy.read_video(video_path)
138
+ video_fps = video_arr.metadata.fps
139
+ num_frames = video_arr.shape[0]
140
+ if num_frames > FRAME_LIMIT:
141
+ gr.Warning(f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.", duration=5)
142
+ video_arr = video_arr[:FRAME_LIMIT]
143
+ num_frames = FRAME_LIMIT
144
+
145
+ # Resize to preview size for faster processing, width = PREVIEW_WIDTH
146
+ height, width = video_arr.shape[1:3]
147
+ new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
148
+
149
+ preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
150
+ input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
151
+
152
+ preview_video = np.array(preview_video)
153
+ input_video = np.array(input_video)
154
+
155
+ video_feature = extract_feature(input_video, model_size)
156
+
157
+ return (
158
+ video_arr, # Original video
159
+ preview_video, # Original preview video, resized for faster processing
160
+ preview_video.copy(), # Copy of preview video for visualization
161
+ input_video, # Resized video input for model
162
+ video_feature, # Extracted feature
163
+ video_fps, # Set the video FPS
164
+ gr.update(open=False), # Close the video input drawer
165
+ model_size, # Set the model size
166
+ preview_video[0], # Set the preview frame to the first frame
167
+ gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=True), # Set slider interactive
168
+ [[] for _ in range(num_frames)], # Set query_points to empty
169
+ [[] for _ in range(num_frames)], # Set query_points_color to empty
170
+ [[] for _ in range(num_frames)],
171
+ 0, # Set query count to 0
172
+ gr.update(interactive=True), # Make the buttons interactive
173
+ gr.update(interactive=True),
174
+ gr.update(interactive=True),
175
+ gr.update(interactive=True),
176
+ )
177
+
178
+
179
+ def track(
180
+ model_size,
181
+ video_preview,
182
+ video_input,
183
+ video_feature,
184
+ video_fps,
185
+ query_points,
186
+ query_points_color,
187
+ query_count,
188
+ ):
189
+ if query_count == 0:
190
+ gr.Warning("Please add query points before tracking.", duration=5)
191
+ return None
192
+
193
+ device = "cuda" if torch.cuda.is_available() else "cpu"
194
+ dtype = torch.bfloat16 if device == "cuda" else torch.float16
195
+
196
+ # Convert query points to tensor, normalize to input resolution
197
+ query_points_tensor = []
198
+ for frame_points in query_points:
199
+ query_points_tensor.extend(frame_points)
200
+
201
+ query_points_tensor = torch.tensor(query_points_tensor).float()
202
+ query_points_tensor *= torch.tensor([
203
+ VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1
204
+ ]) / torch.tensor([
205
+ [video_preview.shape[2], video_preview.shape[1], 1]
206
+ ])
207
+ query_points_tensor = query_points_tensor[None].flip(-1).to(device, dtype) # xyt -> tyx
208
+
209
+ # Preprocess video input
210
+ video_input = (video_input / 255.0) * 2 - 1
211
+ video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
212
+
213
+ model = load_model(model_size=model_size).to(device)
214
+ with torch.autocast(device_type=device, dtype=dtype):
215
+ with torch.no_grad():
216
+ output = model(video_input, query_points_tensor, feature_grids=video_feature)
217
+
218
+ tracks = output['tracks'][0].cpu()
219
+ tracks = tracks * torch.tensor([
220
+ video_preview.shape[2], video_preview.shape[1]
221
+ ]) / torch.tensor([
222
+ VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]
223
+ ])
224
+ tracks = tracks.numpy()
225
+
226
+
227
+ occlusion_logits = output['occlusion']
228
+ pred_occ = torch.sigmoid(occlusion_logits)
229
+ if 'expected_dist' in output:
230
+ expected_dist = output['expected_dist']
231
+ pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(expected_dist))
232
+
233
+ pred_occ = (pred_occ > 0.5)[0].cpu().numpy()
234
+
235
+ # make color array
236
+ colors = []
237
+ for frame_colors in query_points_color:
238
+ colors.extend(frame_colors)
239
+ colors = np.array(colors)
240
+
241
+ painted_video = paint_point_track(
242
+ video_preview,
243
+ tracks,
244
+ ~pred_occ,
245
+ colors,
246
+ )
247
+
248
+ # save video
249
+ video_file_name = uuid.uuid4().hex + ".mp4"
250
+ video_path = os.path.join(os.path.dirname(__file__), "tmp")
251
+ video_file_path = os.path.join(video_path, video_file_name)
252
+ os.makedirs(video_path, exist_ok=True)
253
+
254
+ mediapy.write_video(video_file_path, painted_video, fps=video_fps)
255
+
256
+ return video_file_path
257
+
258
+
259
+ with gr.Blocks() as demo:
260
+ video = gr.State()
261
+ video_queried_preview = gr.State()
262
+ video_preview = gr.State()
263
+ video_input = gr.State()
264
+ video_feautre = gr.State()
265
+ video_fps = gr.State(24)
266
+ model_size = gr.State("small")
267
+
268
+ query_points = gr.State([])
269
+ query_points_color = gr.State([])
270
+ is_tracked_query = gr.State([])
271
+ query_count = gr.State(0)
272
+
273
+ gr.Markdown("# LocoTrack Demo")
274
+ gr.Markdown("This is an interactive demo for LocoTrack. For more details, please refer to the [GitHub repository](https://github.com/KU-CVLAB/LocoTrack) or the [paper](https://arxiv.org/abs/2407.15420).")
275
+
276
+ gr.Markdown("## First step: Choose the model size and upload your video")
277
+ with gr.Row():
278
+ with gr.Accordion("Your video input", open=True) as video_in_drawer:
279
+ model_size_selection = gr.Radio(
280
+ label="Model Size",
281
+ choices=["small", "base"],
282
+ value="small",
283
+ )
284
+ video_in = gr.Video(label="Video Input", format="mp4")
285
+
286
+ gr.Markdown("## Second step: Add query points to track")
287
+ with gr.Row():
288
+
289
+ with gr.Column():
290
+ with gr.Row():
291
+ query_frames = gr.Slider(
292
+ minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
293
+ with gr.Row():
294
+ undo = gr.Button("Undo", interactive=False)
295
+ clear_frame = gr.Button("Clear Frame", interactive=False)
296
+ clear_all = gr.Button("Clear All", interactive=False)
297
+
298
+ with gr.Row():
299
+ current_frame = gr.Image(
300
+ label="Click to add query points",
301
+ type="numpy",
302
+ interactive=False
303
+ )
304
+
305
+ with gr.Row():
306
+ track_button = gr.Button("Track", interactive=False)
307
+
308
+ with gr.Column():
309
+ output_video = gr.Video(
310
+ label="Output Video",
311
+ interactive=False,
312
+ autoplay=True,
313
+ loop=True,
314
+ )
315
+
316
+ video_in.upload(
317
+ fn = preprocess_video_input,
318
+ inputs = [video_in, model_size_selection],
319
+ outputs = [
320
+ video,
321
+ video_preview,
322
+ video_queried_preview,
323
+ video_input,
324
+ video_feautre,
325
+ video_fps,
326
+ video_in_drawer,
327
+ model_size,
328
+ current_frame,
329
+ query_frames,
330
+ query_points,
331
+ query_points_color,
332
+ is_tracked_query,
333
+ query_count,
334
+ undo,
335
+ clear_frame,
336
+ clear_all,
337
+ track_button,
338
+ ],
339
+ queue = False
340
+ )
341
+
342
+ query_frames.change(
343
+ fn = choose_frame,
344
+ inputs = [query_frames, video_queried_preview],
345
+ outputs = [
346
+ current_frame,
347
+ ],
348
+ queue = False
349
+ )
350
+
351
+ current_frame.select(
352
+ fn = get_point,
353
+ inputs = [
354
+ query_frames,
355
+ video_queried_preview,
356
+ query_points,
357
+ query_points_color,
358
+ query_count,
359
+ ],
360
+ outputs = [
361
+ current_frame,
362
+ video_queried_preview,
363
+ query_points,
364
+ query_points_color,
365
+ query_count
366
+ ],
367
+ queue = False
368
+ )
369
+
370
+ undo.click(
371
+ fn = undo_point,
372
+ inputs = [
373
+ query_frames,
374
+ video_preview,
375
+ video_queried_preview,
376
+ query_points,
377
+ query_points_color,
378
+ query_count
379
+ ],
380
+ outputs = [
381
+ current_frame,
382
+ video_queried_preview,
383
+ query_points,
384
+ query_points_color,
385
+ query_count
386
+ ],
387
+ queue = False
388
+ )
389
+
390
+ clear_frame.click(
391
+ fn = clear_frame_fn,
392
+ inputs = [
393
+ query_frames,
394
+ video_preview,
395
+ video_queried_preview,
396
+ query_points,
397
+ query_points_color,
398
+ query_count
399
+ ],
400
+ outputs = [
401
+ current_frame,
402
+ video_queried_preview,
403
+ query_points,
404
+ query_points_color,
405
+ query_count
406
+ ],
407
+ queue = False
408
+ )
409
+
410
+ clear_all.click(
411
+ fn = clear_all_fn,
412
+ inputs = [
413
+ query_frames,
414
+ video_preview,
415
+ ],
416
+ outputs = [
417
+ current_frame,
418
+ video_queried_preview,
419
+ query_points,
420
+ query_points_color,
421
+ query_count
422
+ ],
423
+ queue = False
424
+ )
425
+
426
+ track_button.click(
427
+ fn = track,
428
+ inputs = [
429
+ model_size,
430
+ video_preview,
431
+ video_input,
432
+ video_feautre,
433
+ video_fps,
434
+ query_points,
435
+ query_points_color,
436
+ query_count,
437
+ ],
438
+ outputs = [
439
+ output_video,
440
+ ],
441
+ queue = True,
442
+ )
443
+
444
+ demo.launch(show_api=False, show_error=True, debug=True)
locotrack_pytorch/README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PyTorch Implementation of LocoTrack
2
+
3
+ ## Preparing the Environment
4
+
5
+ ```bash
6
+ git clone https://github.com/google-research/kubric.git
7
+
8
+ conda create -n locotrack-pytorch python=3.11
9
+ conda activate locotrack-pytorch
10
+
11
+ pip install torch torchvision torchaudio lightning==2.3.3 tensorflow_datasets tensorflow matplotlib mediapy tensorflow_graphics einshape wandb
12
+ ```
13
+
14
+ ## LocoTrack Evaluation
15
+
16
+ ### 1. Download Pre-trained Weights
17
+
18
+ To evaluate LocoTrack on the benchmarks, first download the pre-trained weights.
19
+
20
+ | Model | Pre-trained Weights |
21
+ |-------------|---------------------|
22
+ | LocoTrack-S | [Link](https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_small.ckpt) |
23
+ | LocoTrack-B | [Link](https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_base.ckpt) |
24
+
25
+ ### 2. Adjust the Config File
26
+
27
+ In `config/default.ini` (or any other config file), add the path to the evaluation datasets to `[TRAINING]-val_dataset_path`. Additionally, adjust the model size for evaluation in `[MODEL]-model_kwargs-model_size`.
28
+
29
+ ### 3. Run Evaluation
30
+
31
+ To evaluate the LocoTrack model, use the `experiment.py` script with the following command-line arguments:
32
+
33
+ ```bash
34
+ python experiment.py --config config/default.ini --mode eval_{dataset_to_eval_1}_..._{dataset_to_eval_N}[_q_first] --ckpt_path /path/to/checkpoint --save_path ./path_to_save_checkpoints/
35
+ ```
36
+
37
+ - `--config`: Specifies the path to the configuration file. Default is `config/default.ini`.
38
+ - `--mode`: Specifies the mode to run the script. Use `eval` to perform evaluation. You can also include additional options for query first mode (`q_first`), and the name of the evaluation datasets. For example:
39
+ - Evaluation of the DAVIS dataset: `eval_davis`
40
+ - Evaluation of DAVIS and RoboTAP in query first mode: `eval_davis_robotap_q_first`
41
+ - `--ckpt_path`: Specifies the path to the checkpoint file. If not provided, the script will use the default checkpoint.
42
+ - `--save_path`: Specifies the path to save logs.
43
+
44
+ Replace `/path/to/checkpoint` with the actual path to your checkpoint file. This command will run the evaluation process and save the results in the specified `save_path`.
45
+
46
+ ## LocoTrack Training
47
+
48
+ ### Training Dataset Preparation
49
+
50
+ Download the panning-MOVi-E dataset used for training (approximately 273GB) from Huggingface using the following script. Git LFS should be installed to download the dataset. To install Git LFS, please refer to this [link](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage?platform=linux). Additionally, downloading instructions for the Huggingface dataset are available at this [link](https://huggingface.co/docs/hub/en/datasets-downloading).
51
+
52
+ ```bash
53
+ git clone [email protected]:datasets/hamacojr/LocoTrack-panning-MOVi-E
54
+ ```
55
+
56
+ ### Training Script
57
+
58
+ Add the path to the downloaded panning-MOVi-E to the `[TRAINING]-kubric_dir` entry in `config/default.ini` (or any other config file). Optionally, for efficient training, change `[TRAINING]-precision` in the config file to `bf16-mixed` to use `bfloat16`. Then, run the training with the following script:
59
+
60
+ ```bash
61
+ python experiment.py --config config/default.ini --mode train_davis --save_path ./path_to_save_checkpoints/
62
+ ```
locotrack_pytorch/config/default.ini ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [TRAINING]
2
+ val_dataset_path = {"davis": "/home/seokjuc/sensei-fs-link/tapvid/tapvid_davis/tapvid_davis.pkl", "robotics": "", "kinetics": "", "robotap": ""}
3
+ kubric_dir = ./kubric
4
+ precision = 32
5
+ batch_size = 4
6
+ val_check_interval = 1000
7
+ log_every_n_steps = 5
8
+ gradient_clip_val = 1.0
9
+ max_steps = 300000
10
+
11
+ [MODEL]
12
+ model_kwargs = {"model_size": "base", "num_pips_iter": 4}
13
+ model_forward_kwargs = {"refinement_resolutions": ((256, 256),), "query_chunk_size": 256}
14
+
15
+ [LOSS]
16
+ loss_name = tapir_loss
17
+ loss_kwargs = {}
18
+
19
+ [OPTIMIZER]
20
+ optimizer_name = AdamW
21
+ optimizer_kwargs = {"lr": 1e-3, "weight_decay": 1e-3, "betas": (0.9, 0.95)}
22
+
23
+ [SCHEDULER]
24
+ scheduler_name = OneCycleLR
25
+ scheduler_kwargs = {"max_lr": 1e-3, "pct_start": 0.003, "total_steps": 300000}
locotrack_pytorch/data/evaluation_datasets.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Evaluation dataset creation functions."""
17
+
18
+ import csv
19
+ import functools
20
+ import io
21
+ import os
22
+ from os import path
23
+ import pickle
24
+ import random
25
+ from typing import Iterable, Mapping, Optional, Tuple, Union
26
+
27
+ from absl import logging
28
+
29
+ import mediapy as media
30
+ import numpy as np
31
+ from PIL import Image
32
+ import scipy.io as sio
33
+ import tensorflow as tf
34
+ import tensorflow_datasets as tfds
35
+
36
+ from models.utils import convert_grid_coordinates
37
+
38
+ DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
39
+
40
+
41
+ def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
42
+ """Resize a video to output_size."""
43
+ # If you have a GPU, consider replacing this with a GPU-enabled resize op,
44
+ # such as a jitted jax.image.resize. It will make things faster.
45
+ return media.resize_video(video, output_size)
46
+
47
+
48
+ def compute_tapvid_metrics(
49
+ query_points: np.ndarray,
50
+ gt_occluded: np.ndarray,
51
+ gt_tracks: np.ndarray,
52
+ pred_occluded: np.ndarray,
53
+ pred_tracks: np.ndarray,
54
+ query_mode: str,
55
+ get_trackwise_metrics: bool = False,
56
+ ) -> Mapping[str, np.ndarray]:
57
+ """Computes TAP-Vid metrics (Jaccard, Pts.
58
+
59
+ Within Thresh, Occ.
60
+
61
+ Acc.)
62
+
63
+ See the TAP-Vid paper for details on the metric computation. All inputs are
64
+ given in raster coordinates. The first three arguments should be the direct
65
+ outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
66
+ The paper metrics assume these are scaled relative to 256x256 images.
67
+ pred_occluded and pred_tracks are your algorithm's predictions.
68
+
69
+ This function takes a batch of inputs, and computes metrics separately for
70
+ each video. The metrics for the full benchmark are a simple mean of the
71
+ metrics across the full set of videos. These numbers are between 0 and 1,
72
+ but the paper multiplies them by 100 to ease reading.
73
+
74
+ Args:
75
+ query_points: The query points, an in the format [t, y, x]. Its size is
76
+ [b, n, 3], where b is the batch size and n is the number of queries
77
+ gt_occluded: A boolean array of shape [b, n, t], where t is the number of
78
+ frames. True indicates that the point is occluded.
79
+ gt_tracks: The target points, of shape [b, n, t, 2]. Each point is in the
80
+ format [x, y]
81
+ pred_occluded: A boolean array of predicted occlusions, in the same format
82
+ as gt_occluded.
83
+ pred_tracks: An array of track predictions from your algorithm, in the same
84
+ format as gt_tracks.
85
+ query_mode: Either 'first' or 'strided', depending on how queries are
86
+ sampled. If 'first', we assume the prior knowledge that all points
87
+ before the query point are occluded, and these are removed from the
88
+ evaluation.
89
+ get_trackwise_metrics: if True, the metrics will be computed for every
90
+ track (rather than every video, which is the default). This means
91
+ every output tensor will have an extra axis [batch, num_tracks] rather
92
+ than simply (batch).
93
+
94
+ Returns:
95
+ A dict with the following keys:
96
+
97
+ occlusion_accuracy: Accuracy at predicting occlusion.
98
+ pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
99
+ predicted to be within the given pixel threshold, ignoring occlusion
100
+ prediction.
101
+ jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
102
+ threshold
103
+ average_pts_within_thresh: average across pts_within_{x}
104
+ average_jaccard: average across jaccard_{x}
105
+ """
106
+
107
+ summing_axis = (2,) if get_trackwise_metrics else (1, 2)
108
+
109
+ metrics = {}
110
+
111
+ eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
112
+ if query_mode == 'first':
113
+ # evaluate frames after the query frame
114
+ query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
115
+ elif query_mode == 'strided':
116
+ # evaluate all frames except the query frame
117
+ query_frame_to_eval_frames = 1 - eye
118
+ else:
119
+ raise ValueError('Unknown query mode ' + query_mode)
120
+
121
+ query_frame = query_points[..., 0]
122
+ query_frame = np.round(query_frame).astype(np.int32)
123
+ evaluation_points = query_frame_to_eval_frames[query_frame] > 0
124
+
125
+ # Occlusion accuracy is simply how often the predicted occlusion equals the
126
+ # ground truth.
127
+ occ_acc = np.sum(
128
+ np.equal(pred_occluded, gt_occluded) & evaluation_points,
129
+ axis=summing_axis,
130
+ ) / np.sum(evaluation_points, axis=summing_axis)
131
+ metrics['occlusion_accuracy'] = occ_acc
132
+
133
+ # Next, convert the predictions and ground truth positions into pixel
134
+ # coordinates.
135
+ visible = np.logical_not(gt_occluded)
136
+ pred_visible = np.logical_not(pred_occluded)
137
+ all_frac_within = []
138
+ all_jaccard = []
139
+ for thresh in [1, 2, 4, 8, 16]:
140
+ # True positives are points that are within the threshold and where both
141
+ # the prediction and the ground truth are listed as visible.
142
+ within_dist = np.sum(
143
+ np.square(pred_tracks - gt_tracks),
144
+ axis=-1,
145
+ ) < np.square(thresh)
146
+ is_correct = np.logical_and(within_dist, visible)
147
+
148
+ # Compute the frac_within_threshold, which is the fraction of points
149
+ # within the threshold among points that are visible in the ground truth,
150
+ # ignoring whether they're predicted to be visible.
151
+ count_correct = np.sum(
152
+ is_correct & evaluation_points,
153
+ axis=summing_axis,
154
+ )
155
+ count_visible_points = np.sum(
156
+ visible & evaluation_points, axis=summing_axis
157
+ )
158
+ frac_correct = count_correct / count_visible_points
159
+ metrics['pts_within_' + str(thresh)] = frac_correct
160
+ all_frac_within.append(frac_correct)
161
+
162
+ true_positives = np.sum(
163
+ is_correct & pred_visible & evaluation_points, axis=summing_axis
164
+ )
165
+
166
+ # The denominator of the jaccard metric is the true positives plus
167
+ # false positives plus false negatives. However, note that true positives
168
+ # plus false negatives is simply the number of points in the ground truth
169
+ # which is easier to compute than trying to compute all three quantities.
170
+ # Thus we just add the number of points in the ground truth to the number
171
+ # of false positives.
172
+ #
173
+ # False positives are simply points that are predicted to be visible,
174
+ # but the ground truth is not visible or too far from the prediction.
175
+ gt_positives = np.sum(visible & evaluation_points, axis=summing_axis)
176
+ false_positives = (~visible) & pred_visible
177
+ false_positives = false_positives | ((~within_dist) & pred_visible)
178
+ false_positives = np.sum(
179
+ false_positives & evaluation_points, axis=summing_axis
180
+ )
181
+ jaccard = true_positives / (gt_positives + false_positives)
182
+ metrics['jaccard_' + str(thresh)] = jaccard
183
+ all_jaccard.append(jaccard)
184
+ metrics['average_jaccard'] = np.mean(
185
+ np.stack(all_jaccard, axis=1),
186
+ axis=1,
187
+ )
188
+ metrics['average_pts_within_thresh'] = np.mean(
189
+ np.stack(all_frac_within, axis=1),
190
+ axis=1,
191
+ )
192
+ return metrics
193
+
194
+
195
+ def latex_table(mean_scalars: Mapping[str, float]) -> str:
196
+ """Generate a latex table for displaying TAP-Vid and PCK metrics."""
197
+ if 'average_jaccard' in mean_scalars:
198
+ latex_fields = [
199
+ 'average_jaccard',
200
+ 'average_pts_within_thresh',
201
+ 'occlusion_accuracy',
202
+ 'jaccard_1',
203
+ 'jaccard_2',
204
+ 'jaccard_4',
205
+ 'jaccard_8',
206
+ 'jaccard_16',
207
+ 'pts_within_1',
208
+ 'pts_within_2',
209
+ 'pts_within_4',
210
+ 'pts_within_8',
211
+ 'pts_within_16',
212
+ ]
213
+ header = (
214
+ 'AJ & $<\\delta^{x}_{avg}$ & OA & Jac. $\\delta^{0}$ & '
215
+ + 'Jac. $\\delta^{1}$ & Jac. $\\delta^{2}$ & '
216
+ + 'Jac. $\\delta^{3}$ & Jac. $\\delta^{4}$ & $<\\delta^{0}$ & '
217
+ + '$<\\delta^{1}$ & $<\\delta^{2}$ & $<\\delta^{3}$ & '
218
+ + '$<\\delta^{4}$'
219
+ )
220
+ else:
221
222
+ header = ' & '.join(latex_fields)
223
+
224
+ body = ' & '.join(
225
+ [f'{float(np.array(mean_scalars[x]*100)):.3}' for x in latex_fields]
226
+ )
227
+ return '\n'.join([header, body])
228
+
229
+
230
+ def sample_queries_strided(
231
+ target_occluded: np.ndarray,
232
+ target_points: np.ndarray,
233
+ frames: np.ndarray,
234
+ query_stride: int = 5,
235
+ ) -> Mapping[str, np.ndarray]:
236
+ """Package a set of frames and tracks for use in TAPNet evaluations.
237
+
238
+ Given a set of frames and tracks with no query points, sample queries
239
+ strided every query_stride frames, ignoring points that are not visible
240
+ at the selected frames.
241
+
242
+ Args:
243
+ target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
244
+ where True indicates occluded.
245
+ target_points: Position, of shape [n_tracks, n_frames, 2], where each point
246
+ is [x,y] scaled between 0 and 1.
247
+ frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
248
+ -1 and 1.
249
+ query_stride: When sampling query points, search for un-occluded points
250
+ every query_stride frames and convert each one into a query.
251
+
252
+ Returns:
253
+ A dict with the keys:
254
+ video: Video tensor of shape [1, n_frames, height, width, 3]. The video
255
+ has floats scaled to the range [-1, 1].
256
+ query_points: Query points of shape [1, n_queries, 3] where
257
+ each point is [t, y, x] scaled to the range [-1, 1].
258
+ target_points: Target points of shape [1, n_queries, n_frames, 2] where
259
+ each point is [x, y] scaled to the range [-1, 1].
260
+ trackgroup: Index of the original track that each query point was
261
+ sampled from. This is useful for visualization.
262
+ """
263
+ tracks = []
264
+ occs = []
265
+ queries = []
266
+ trackgroups = []
267
+ total = 0
268
+ trackgroup = np.arange(target_occluded.shape[0])
269
+ for i in range(0, target_occluded.shape[1], query_stride):
270
+ mask = target_occluded[:, i] == 0
271
+ query = np.stack(
272
+ [
273
+ i * np.ones(target_occluded.shape[0:1]),
274
+ target_points[:, i, 1],
275
+ target_points[:, i, 0],
276
+ ],
277
+ axis=-1,
278
+ )
279
+ queries.append(query[mask])
280
+ tracks.append(target_points[mask])
281
+ occs.append(target_occluded[mask])
282
+ trackgroups.append(trackgroup[mask])
283
+ total += np.array(np.sum(target_occluded[:, i] == 0))
284
+
285
+ return {
286
+ 'video': frames[np.newaxis, ...],
287
+ 'query_points': np.concatenate(queries, axis=0)[np.newaxis, ...],
288
+ 'target_points': np.concatenate(tracks, axis=0)[np.newaxis, ...],
289
+ 'occluded': np.concatenate(occs, axis=0)[np.newaxis, ...],
290
+ 'trackgroup': np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
291
+ }
292
+
293
+
294
+ def sample_queries_first(
295
+ target_occluded: np.ndarray,
296
+ target_points: np.ndarray,
297
+ frames: np.ndarray,
298
+ ) -> Mapping[str, np.ndarray]:
299
+ """Package a set of frames and tracks for use in TAPNet evaluations.
300
+
301
+ Given a set of frames and tracks with no query points, use the first
302
+ visible point in each track as the query.
303
+
304
+ Args:
305
+ target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
306
+ where True indicates occluded.
307
+ target_points: Position, of shape [n_tracks, n_frames, 2], where each point
308
+ is [x,y] scaled between 0 and 1.
309
+ frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
310
+ -1 and 1.
311
+
312
+ Returns:
313
+ A dict with the keys:
314
+ video: Video tensor of shape [1, n_frames, height, width, 3]
315
+ query_points: Query points of shape [1, n_queries, 3] where
316
+ each point is [t, y, x] scaled to the range [-1, 1]
317
+ target_points: Target points of shape [1, n_queries, n_frames, 2] where
318
+ each point is [x, y] scaled to the range [-1, 1]
319
+ """
320
+
321
+ valid = np.sum(~target_occluded, axis=1) > 0
322
+ target_points = target_points[valid, :]
323
+ target_occluded = target_occluded[valid, :]
324
+
325
+ query_points = []
326
+ for i in range(target_points.shape[0]):
327
+ index = np.where(target_occluded[i] == 0)[0][0]
328
+ x, y = target_points[i, index, 0], target_points[i, index, 1]
329
+ query_points.append(np.array([index, y, x])) # [t, y, x]
330
+ query_points = np.stack(query_points, axis=0)
331
+
332
+ return {
333
+ 'video': frames[np.newaxis, ...],
334
+ 'query_points': query_points[np.newaxis, ...],
335
+ 'target_points': target_points[np.newaxis, ...],
336
+ 'occluded': target_occluded[np.newaxis, ...],
337
+ }
338
+
339
+
340
+ def create_jhmdb_dataset(
341
+ jhmdb_path: str, resolution: Optional[Tuple[int, int]] = (256, 256)
342
+ ) -> Iterable[DatasetElement]:
343
+ """JHMDB dataset, including fields required for PCK evaluation."""
344
+ videos = []
345
+ for file in tf.io.gfile.listdir(path.join(gt_dir, 'splits')):
346
+ # JHMDB file containing the first split, which is standard for this type of
347
+ # evaluation.
348
+ if not file.endswith('split1.txt'):
349
+ continue
350
+
351
+ video_folder = '_'.join(file.split('_')[:-2])
352
+ for video in tf.io.gfile.GFile(path.join(gt_dir, 'splits', file), 'r'):
353
+ video, traintest = video.split()
354
+ video, _ = video.split('.')
355
+
356
+ traintest = int(traintest)
357
+ video_path = path.join(video_folder, video)
358
+
359
+ if traintest == 2:
360
+ videos.append(video_path)
361
+
362
+ if not videos:
363
+ raise ValueError('No JHMDB videos found in directory ' + str(jhmdb_path))
364
+
365
+ # Shuffle so numbers converge faster.
366
+ random.shuffle(videos)
367
+
368
+ for video in videos:
369
+ logging.info(video)
370
+ joints = path.join(gt_dir, 'joint_positions', video, 'joint_positions.mat')
371
+
372
+ if not tf.io.gfile.exists(joints):
373
+ logging.info('skip %s', video)
374
+ continue
375
+
376
+ gt_pose = sio.loadmat(tf.io.gfile.GFile(joints, 'rb'))['pos_img']
377
+ gt_pose = np.transpose(gt_pose, [1, 2, 0])
378
+ frames = path.join(gt_dir, 'Rename_Images', video, '*.png')
379
+ framefil = tf.io.gfile.glob(frames)
380
+ framefil.sort()
381
+
382
+ def read_frame(f):
383
+ im = Image.open(tf.io.gfile.GFile(f, 'rb'))
384
+ im = im.convert('RGB')
385
+ im_data = np.array(im.getdata(), np.uint8)
386
+ return im_data.reshape([im.size[1], im.size[0], 3])
387
+
388
+ frames = [read_frame(x) for x in framefil]
389
+ frames = np.stack(frames)
390
+ height = frames.shape[1]
391
+ width = frames.shape[2]
392
+ invalid_x = np.logical_or(
393
+ gt_pose[:, 0:1, 0] < 0,
394
+ gt_pose[:, 0:1, 0] >= width,
395
+ )
396
+ invalid_y = np.logical_or(
397
+ gt_pose[:, 0:1, 1] < 0,
398
+ gt_pose[:, 0:1, 1] >= height,
399
+ )
400
+ invalid = np.logical_or(invalid_x, invalid_y)
401
+ invalid = np.tile(invalid, [1, gt_pose.shape[1]])
402
+ invalid = invalid[:, :, np.newaxis].astype(np.float32)
403
+ gt_pose_orig = gt_pose
404
+
405
+ if resolution is not None and resolution != frames.shape[1:3]:
406
+ frames = resize_video(frames, resolution)
407
+ frames = frames / (255.0 / 2.0) - 1.0
408
+ queries = gt_pose[:, 0]
409
+ queries = np.concatenate(
410
+ [queries[..., 0:1] * 0, queries[..., ::-1]],
411
+ axis=-1,
412
+ )
413
+ gt_pose = convert_grid_coordinates(
414
+ gt_pose,
415
+ np.array([width, height]),
416
+ np.array([frames.shape[2], frames.shape[1]]),
417
+ )
418
+ # Set invalid poses to -1 (outside the frame)
419
+ gt_pose = (1.0 - invalid) * gt_pose + invalid * (-1.0)
420
+
421
+ if gt_pose.shape[1] < frames.shape[0]:
422
+ # Some videos have pose sequences that are shorter than the frame
423
+ # sequence (usually because the person disappears). In this case,
424
+ # truncate the video.
425
+ logging.warning('short video!!')
426
+ frames = frames[: gt_pose.shape[1]]
427
+
428
+ converted = {
429
+ 'video': frames[np.newaxis, ...],
430
+ 'query_points': queries[np.newaxis, ...],
431
+ 'target_points': gt_pose[np.newaxis, ...],
432
+ 'gt_pose': gt_pose[np.newaxis, ...],
433
+ 'gt_pose_orig': gt_pose_orig[np.newaxis, ...],
434
+ 'occluded': gt_pose[np.newaxis, ..., 0] * 0,
435
+ 'fname': video,
436
+ 'im_size': np.array([height, width]),
437
+ }
438
+ yield {'jhmdb': converted}
439
+
440
+
441
+ def create_kubric_eval_train_dataset(
442
+ mode: str,
443
+ train_size: Tuple[int, int] = (256, 256),
444
+ max_dataset_size: int = 100,
445
+ ) -> Iterable[DatasetElement]:
446
+ """Dataset for evaluating performance on Kubric training data."""
447
+
448
+ # Lazy import kubric because requirements_inference doesn't include it.
449
+ from kubric.challenges.point_tracking import dataset
450
+ res = dataset.create_point_tracking_dataset(
451
+ split='train',
452
+ train_size=train_size,
453
+ batch_dims=[1],
454
+ shuffle_buffer_size=None,
455
+ repeat=False,
456
+ vflip='vflip' in mode,
457
+ random_crop=False,
458
+ )
459
+ np_ds = tfds.as_numpy(res)
460
+
461
+ num_returned = 0
462
+ for data in np_ds:
463
+ if num_returned >= max_dataset_size:
464
+ break
465
+ num_returned += 1
466
+ yield {'kubric': data}
467
+
468
+
469
+ def create_kubric_eval_dataset(
470
+ mode: str, train_size: Tuple[int, int] = (256, 256)
471
+ ) -> Iterable[DatasetElement]:
472
+ """Dataset for evaluating performance on Kubric val data."""
473
+ # Lazy import kubric because requirements_inference doesn't include it.
474
+ from kubric.challenges.point_tracking import dataset
475
+ res = dataset.create_point_tracking_dataset(
476
+ split='validation',
477
+ train_size=train_size,
478
+ batch_dims=[1],
479
+ shuffle_buffer_size=None,
480
+ repeat=False,
481
+ vflip='vflip' in mode,
482
+ random_crop=False,
483
+ )
484
+ np_ds = tfds.as_numpy(res)
485
+
486
+ for data in np_ds:
487
+ yield {'kubric': data}
488
+
489
+
490
+ def create_davis_dataset(
491
+ davis_points_path: str,
492
+ query_mode: str = 'strided',
493
+ full_resolution=False,
494
+ resolution: Optional[Tuple[int, int]] = (256, 256),
495
+ ) -> Iterable[DatasetElement]:
496
+ """Dataset for evaluating performance on DAVIS data."""
497
+ pickle_path = davis_points_path
498
+
499
+ with tf.io.gfile.GFile(pickle_path, 'rb') as f:
500
+ davis_points_dataset = pickle.load(f)
501
+
502
+ if full_resolution:
503
+ ds, _ = tfds.load(
504
+ 'davis/full_resolution', split='validation', with_info=True
505
+ )
506
+ to_iterate = tfds.as_numpy(ds)
507
+ else:
508
+ to_iterate = davis_points_dataset.keys()
509
+
510
+ for tmp in to_iterate:
511
+ if full_resolution:
512
+ frames = tmp['video']['frames']
513
+ video_name = tmp['metadata']['video_name'].decode()
514
+ else:
515
+ video_name = tmp
516
+ frames = davis_points_dataset[video_name]['video']
517
+ if resolution is not None and resolution != frames.shape[1:3]:
518
+ frames = resize_video(frames, resolution)
519
+
520
+ frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
521
+ target_points = davis_points_dataset[video_name]['points']
522
+ target_occ = davis_points_dataset[video_name]['occluded']
523
+ target_points = target_points * np.array([frames.shape[2], frames.shape[1]])
524
+
525
+ if query_mode == 'strided':
526
+ converted = sample_queries_strided(target_occ, target_points, frames)
527
+ elif query_mode == 'first':
528
+ converted = sample_queries_first(target_occ, target_points, frames)
529
+ else:
530
+ raise ValueError(f'Unknown query mode {query_mode}.')
531
+
532
+ yield {'davis': converted}
533
+
534
+
535
+ def create_rgb_stacking_dataset(
536
+ robotics_points_path: str,
537
+ query_mode: str = 'strided',
538
+ resolution: Optional[Tuple[int, int]] = (256, 256),
539
+ ) -> Iterable[DatasetElement]:
540
+ """Dataset for evaluating performance on robotics data."""
541
+ pickle_path = robotics_points_path
542
+
543
+ with tf.io.gfile.GFile(pickle_path, 'rb') as f:
544
+ robotics_points_dataset = pickle.load(f)
545
+
546
+ for example in robotics_points_dataset:
547
+ frames = example['video']
548
+ if resolution is not None and resolution != frames.shape[1:3]:
549
+ frames = resize_video(frames, resolution)
550
+ frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
551
+ target_points = example['points']
552
+ target_occ = example['occluded']
553
+ target_points = target_points * np.array([frames.shape[2], frames.shape[1]])
554
+
555
+ if query_mode == 'strided':
556
+ converted = sample_queries_strided(target_occ, target_points, frames)
557
+ elif query_mode == 'first':
558
+ converted = sample_queries_first(target_occ, target_points, frames)
559
+ else:
560
+ raise ValueError(f'Unknown query mode {query_mode}.')
561
+
562
+ yield {'robotics': converted}
563
+
564
+
565
+ def create_kinetics_dataset(
566
+ kinetics_path: str, query_mode: str = 'strided',
567
+ resolution: Optional[Tuple[int, int]] = (256, 256),
568
+ ) -> Iterable[DatasetElement]:
569
+ """Dataset for evaluating performance on Kinetics point tracking."""
570
+
571
+ all_paths = tf.io.gfile.glob(path.join(kinetics_path, '*_of_0010.pkl'))
572
+ for pickle_path in all_paths:
573
+ with open(pickle_path, 'rb') as f:
574
+ data = pickle.load(f)
575
+ if isinstance(data, dict):
576
+ data = list(data.values())
577
+
578
+ # idx = random.randint(0, len(data) - 1)
579
+ for idx in range(len(data)):
580
+ example = data[idx]
581
+
582
+ frames = example['video']
583
+
584
+ if isinstance(frames[0], bytes):
585
+ # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
586
+ def decode(frame):
587
+ byteio = io.BytesIO(frame)
588
+ img = Image.open(byteio)
589
+ return np.array(img)
590
+
591
+ frames = np.array([decode(frame) for frame in frames])
592
+
593
+ if resolution is not None and resolution != frames.shape[1:3]:
594
+ frames = resize_video(frames, resolution)
595
+
596
+ frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
597
+ target_points = example['points']
598
+ target_occ = example['occluded']
599
+ target_points *= np.array([frames.shape[2], frames.shape[1]])
600
+
601
+ if query_mode == 'strided':
602
+ converted = sample_queries_strided(target_occ, target_points, frames)
603
+ elif query_mode == 'first':
604
+ converted = sample_queries_first(target_occ, target_points, frames)
605
+ else:
606
+ raise ValueError(f'Unknown query mode {query_mode}.')
607
+
608
+ yield {'kinetics': converted}
609
+
610
+
611
+ def create_robotap_dataset(
612
+ robotics_points_path: str,
613
+ query_mode: str = 'strided',
614
+ resolution: Optional[Tuple[int, int]] = (256, 256),
615
+ ) -> Iterable[DatasetElement]:
616
+ """Dataset for evaluating performance on robotics data."""
617
+ pickle_path = robotics_points_path
618
+
619
+ # with tf.io.gfile.GFile(pickle_path, 'rb') as f:
620
+ # robotics_points_dataset = pickle.load(f)
621
+ robotics_points_dataset = []
622
+ all_paths = tf.io.gfile.glob(path.join(robotics_points_path, '*.pkl'))
623
+ for pickle_path in all_paths:
624
+ with open(pickle_path, 'rb') as f:
625
+ data = pickle.load(f)
626
+ robotics_points_dataset.extend(data.values())
627
+
628
+ for example in robotics_points_dataset:
629
+ frames = example['video']
630
+ if resolution is not None and resolution != frames.shape[1:3]:
631
+ frames = resize_video(frames, resolution)
632
+ frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
633
+ target_points = example['points']
634
+ target_occ = example['occluded']
635
+ target_points = target_points * np.array([frames.shape[2], frames.shape[1]])
636
+
637
+ if query_mode == 'strided':
638
+ converted = sample_queries_strided(target_occ, target_points, frames)
639
+ elif query_mode == 'first':
640
+ converted = sample_queries_first(target_occ, target_points, frames)
641
+ else:
642
+ raise ValueError(f'Unknown query mode {query_mode}.')
643
+
644
+ yield {'robotap': converted}
645
+
646
+
647
+ def create_csv_dataset(
648
+ dataset_name: str,
649
+ csv_path: str,
650
+ video_base_path: str,
651
+ query_mode: str = 'strided',
652
+ resolution: Optional[Tuple[int, int]] = (256, 256),
653
+ max_video_frames: Optional[int] = 1000,
654
+ ) -> Iterable[DatasetElement]:
655
+ """Create an evaluation iterator out of human annotations and videos.
656
+
657
+ Args:
658
+ dataset_name: Name to the dataset.
659
+ csv_path: Path to annotations csv.
660
+ video_base_path: Path to annotated videos.
661
+ query_mode: sample query points from first frame or strided.
662
+ resolution: The video resolution in (height, width).
663
+ max_video_frames: Max length of annotated video.
664
+
665
+ Yields:
666
+ Samples for evaluation.
667
+ """
668
+ point_tracks_all = dict()
669
+ with tf.io.gfile.GFile(csv_path, 'r') as f:
670
+ reader = csv.reader(f, delimiter=',')
671
+ for row in reader:
672
+ video_id = row[0]
673
+ point_tracks = np.array(row[1:]).reshape(-1, 3)
674
+ if video_id in point_tracks_all:
675
+ point_tracks_all[video_id].append(point_tracks)
676
+ else:
677
+ point_tracks_all[video_id] = [point_tracks]
678
+
679
+ for video_id in point_tracks_all:
680
+ if video_id.endswith('.mp4'):
681
+ video_path = path.join(video_base_path, video_id)
682
+ else:
683
+ video_path = path.join(video_base_path, video_id + '.mp4')
684
+ frames = media.read_video(video_path)
685
+ if resolution is not None and resolution != frames.shape[1:3]:
686
+ frames = media.resize_video(frames, resolution)
687
+ frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
688
+
689
+ point_tracks = np.stack(point_tracks_all[video_id], axis=0)
690
+ point_tracks = point_tracks.astype(np.float32)
691
+ if frames.shape[0] < point_tracks.shape[1]:
692
+ logging.info('Warning: short video!')
693
+ point_tracks = point_tracks[:, : frames.shape[0]]
694
+ point_tracks, occluded = point_tracks[..., 0:2], point_tracks[..., 2]
695
+ occluded = occluded > 0
696
+ target_points = point_tracks * np.array([frames.shape[2], frames.shape[1]])
697
+
698
+ num_splits = int(np.ceil(frames.shape[0] / max_video_frames))
699
+ if num_splits > 1:
700
+ print(f'Going to split the video {video_id} into {num_splits}')
701
+ for i in range(num_splits):
702
+ start_index = i * frames.shape[0] // num_splits
703
+ end_index = (i + 1) * frames.shape[0] // num_splits
704
+ sub_occluded = occluded[:, start_index:end_index]
705
+ sub_target_points = target_points[:, start_index:end_index]
706
+ sub_frames = frames[start_index:end_index]
707
+
708
+ if query_mode == 'strided':
709
+ converted = sample_queries_strided(
710
+ sub_occluded, sub_target_points, sub_frames
711
+ )
712
+ elif query_mode == 'first':
713
+ converted = sample_queries_first(
714
+ sub_occluded, sub_target_points, sub_frames
715
+ )
716
+ else:
717
+ raise ValueError(f'Unknown query mode {query_mode}.')
718
+
719
+ yield {dataset_name: converted}
720
+
721
+
722
+ import torch
723
+ from torch.utils.data import Dataset
724
+
725
+
726
+ class CustomDataset(Dataset):
727
+ def __init__(self, data_generator: Iterable[DatasetElement], key: str):
728
+ self.data = list(data_generator)
729
+ self.key = key
730
+
731
+ def __len__(self):
732
+ return len(self.data)
733
+
734
+ def __getitem__(self, idx):
735
+ data = self.data[idx][self.key]
736
+ data = {k: torch.tensor(v)[0] if isinstance(v, np.ndarray) else v for k, v in data.items()}
737
+ # Convert double to float
738
+ data = {k: v.float() if v.dtype == torch.float64 else v for k, v in data.items()}
739
+ return data
740
+
741
+
742
+ def get_eval_dataset(mode, path, resolution=(256, 256)):
743
+ query_mode = 'first' if 'q_first' in mode else 'strided'
744
+ datasets = {}
745
+ if 'jhmdb' in mode:
746
+ key = 'jhmdb'
747
+ dataset = create_jhmdb_dataset(path[key], resolution)
748
+ datasets[key] = CustomDataset(dataset, key)
749
+ if 'davis' in mode:
750
+ key = 'davis'
751
+ dataset = create_davis_dataset(path[key], query_mode, False, resolution=resolution)
752
+ datasets[key] = CustomDataset(dataset, key)
753
+ if 'robotics' in mode:
754
+ key = 'robotics'
755
+ dataset = create_rgb_stacking_dataset(path[key], query_mode, resolution)
756
+ datasets[key] = CustomDataset(dataset, key)
757
+ if 'kinetics' in mode:
758
+ key = 'kinetics'
759
+ dataset = create_kinetics_dataset(path[key], query_mode, resolution)
760
+ datasets[key] = CustomDataset(dataset, key)
761
+ if 'robotap' in mode:
762
+ key = 'robotap'
763
+ dataset = create_robotap_dataset(path[key], query_mode, resolution)
764
+ datasets[key] = CustomDataset(dataset, key)
765
+
766
+ if len(datasets) == 0:
767
+ raise ValueError(f'No dataset found for mode {mode}.')
768
+
769
+ return datasets
770
+
771
+
772
+ if __name__ == '__main__':
773
+ # Disable all GPUS
774
+ tf.config.set_visible_devices([], 'GPU')
775
+ visible_devices = tf.config.get_visible_devices()
776
+ for device in visible_devices:
777
+ assert device.device_type != 'GPU'
778
+
779
+ dataset_name = 'davis'
780
+ dataset_path = '/media/data2/PointTracking/tapvid/tapnet_dataset/tapvid_davis/tapvid_davis.pkl'
781
+
782
+ dataset = get_eval_dataset(dataset_name, dataset_path, 'strided', (256, 256))
783
+ breakpoint()
784
+ pass
locotrack_pytorch/data/kubric_data.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Mapping
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ import functools
7
+ import tensorflow_datasets as tfds
8
+ import tensorflow as tf
9
+ import torch.distributed
10
+ from kubric.challenges.point_tracking.dataset import add_tracks
11
+
12
+
13
+ # Disable all GPUS
14
+ tf.config.set_visible_devices([], 'GPU')
15
+ visible_devices = tf.config.get_visible_devices()
16
+ for device in visible_devices:
17
+ assert device.device_type != 'GPU'
18
+
19
+
20
+ def default_color_augmentation_fn(
21
+ inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
22
+ """Standard color augmentation for videos.
23
+
24
+ Args:
25
+ inputs: A DatasetElement containing the item 'video' which will have
26
+ augmentations applied to it.
27
+
28
+ Returns:
29
+ A DatasetElement with all the same data as the original, except that
30
+ the video has augmentations applied.
31
+ """
32
+ zero_centering_image = True
33
+ prob_color_augment = 0.8
34
+ prob_color_drop = 0.2
35
+
36
+ frames = inputs['video']
37
+ if frames.dtype != tf.float32:
38
+ raise ValueError('`frames` should be in float32.')
39
+
40
+ def color_augment(video: tf.Tensor) -> tf.Tensor:
41
+ """Do standard color augmentations."""
42
+ # Note the same augmentation will be applied to all frames of the video.
43
+ if zero_centering_image:
44
+ video = 0.5 * (video + 1.0)
45
+ video = tf.image.random_brightness(video, max_delta=32. / 255.)
46
+ video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
47
+ video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
48
+ video = tf.image.random_hue(video, max_delta=0.2)
49
+ video = tf.clip_by_value(video, 0.0, 1.0)
50
+ if zero_centering_image:
51
+ video = 2 * (video-0.5)
52
+ return video
53
+
54
+ def color_drop(video: tf.Tensor) -> tf.Tensor:
55
+ video = tf.image.rgb_to_grayscale(video)
56
+ video = tf.tile(video, [1, 1, 1, 1, 3])
57
+ return video
58
+
59
+ # Eventually applies color augmentation.
60
+ coin_toss_color_augment = tf.random.uniform(
61
+ [], minval=0, maxval=1, dtype=tf.float32)
62
+ frames = tf.cond(
63
+ pred=tf.less(coin_toss_color_augment,
64
+ tf.cast(prob_color_augment, tf.float32)),
65
+ true_fn=lambda: color_augment(frames),
66
+ false_fn=lambda: frames)
67
+
68
+ # Eventually applies color drop.
69
+ coin_toss_color_drop = tf.random.uniform(
70
+ [], minval=0, maxval=1, dtype=tf.float32)
71
+ frames = tf.cond(
72
+ pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)),
73
+ true_fn=lambda: color_drop(frames),
74
+ false_fn=lambda: frames)
75
+ result = {**inputs}
76
+ result['video'] = frames
77
+
78
+ return result
79
+
80
+
81
+ def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
82
+ return ds.map(
83
+ default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)
84
+
85
+
86
+ def create_point_tracking_dataset(
87
+ data_dir=None,
88
+ color_augmentation=True,
89
+ train_size=(256, 256),
90
+ shuffle_buffer_size=256,
91
+ split='train',
92
+ # batch_dims=tuple(),
93
+ batch_size=1,
94
+ repeat=True,
95
+ vflip=False,
96
+ random_crop=True,
97
+ tracks_to_sample=256,
98
+ sampling_stride=4,
99
+ max_seg_id=40,
100
+ max_sampled_frac=0.1,
101
+ num_parallel_point_extraction_calls=16,
102
+ **kwargs):
103
+ """Construct a dataset for point tracking using Kubric.
104
+
105
+ Args:
106
+ train_size: Tuple of 2 ints. Cropped output will be at this resolution
107
+ shuffle_buffer_size: Int. Size of the shuffle buffer
108
+ split: Which split to construct from Kubric. Can be 'train' or
109
+ 'validation'.
110
+ batch_dims: Sequence of ints. Add multiple examples into a batch of this
111
+ shape.
112
+ repeat: Bool. whether to repeat the dataset.
113
+ vflip: Bool. whether to vertically flip the dataset to test generalization.
114
+ random_crop: Bool. whether to randomly crop videos
115
+ tracks_to_sample: Int. Total number of tracks to sample per video.
116
+ sampling_stride: Int. For efficiency, query points are sampled from a
117
+ random grid of this stride.
118
+ max_seg_id: Int. The maxium segment id in the video. Note the size of
119
+ the to graph is proportional to this number, so prefer small values.
120
+ max_sampled_frac: Float. The maximum fraction of points to sample from each
121
+ object, out of all points that lie on the sampling grid.
122
+ num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
123
+ map function for point extraction.
124
+ snap_to_occluder: If true, query points within 1 pixel of occlusion
125
+ boundaries will track the occluding surface rather than the background.
126
+ This results in models which are biased to track foreground objects
127
+ instead of background. Whether this is desirable depends on downstream
128
+ applications.
129
+ **kwargs: additional args to pass to tfds.load.
130
+
131
+ Returns:
132
+ The dataset generator.
133
+ """
134
+ ds = tfds.load(
135
+ 'panning_movi_e/256x256',
136
+ data_dir=data_dir,
137
+ shuffle_files=shuffle_buffer_size is not None,
138
+ **kwargs)
139
+
140
+ ds = ds[split]
141
+ if repeat:
142
+ ds = ds.repeat()
143
+ ds = ds.map(
144
+ functools.partial(
145
+ add_tracks,
146
+ train_size=train_size,
147
+ vflip=vflip,
148
+ random_crop=random_crop,
149
+ tracks_to_sample=tracks_to_sample,
150
+ sampling_stride=sampling_stride,
151
+ max_seg_id=max_seg_id,
152
+ max_sampled_frac=max_sampled_frac),
153
+ num_parallel_calls=num_parallel_point_extraction_calls)
154
+ if shuffle_buffer_size is not None:
155
+ ds = ds.shuffle(shuffle_buffer_size)
156
+
157
+ ds = ds.batch(batch_size)
158
+
159
+ if color_augmentation:
160
+ ds = add_default_data_augmentation(ds)
161
+ ds = tfds.as_numpy(ds)
162
+
163
+ it = iter(ds)
164
+ while True:
165
+ data = next(it)
166
+ yield data
167
+
168
+
169
+ class KubricData:
170
+ def __init__(
171
+ self,
172
+ global_rank,
173
+ data_dir,
174
+ **kwargs
175
+ ):
176
+ self.global_rank = global_rank
177
+
178
+ if self.global_rank == 0:
179
+ self.data = create_point_tracking_dataset(
180
+ data_dir=data_dir,
181
+ **kwargs
182
+ )
183
+
184
+ def __getitem__(self, idx):
185
+ if self.global_rank == 0:
186
+ batch_all = next(self.data)
187
+ batch_list = []
188
+
189
+ world_size = torch.distributed.get_world_size()
190
+ batch_size = batch_all['video'].shape[0] // world_size
191
+
192
+
193
+ for i in range(world_size):
194
+ batch = {}
195
+ for k, v in batch_all.items():
196
+ if isinstance(v, (np.ndarray, torch.Tensor)):
197
+ batch[k] = torch.tensor(v[i * batch_size: (i + 1) * batch_size])
198
+ batch_list.append(batch)
199
+ else:
200
+ batch_list = [None] * torch.distributed.get_world_size()
201
+
202
+
203
+ batch = [None]
204
+ torch.distributed.scatter_object_list(batch, batch_list, src=0)
205
+
206
+ return batch[0]
207
+
208
+
209
+ if __name__ == '__main__':
210
+
211
+ import torch.nn as nn
212
+ import lightning as L
213
+ from lightning.pytorch.strategies import DDPStrategy
214
+
215
+ class Model(L.LightningModule):
216
+ def __init__(self):
217
+ super().__init__()
218
+ self.model = nn.Linear(256 * 256 * 3 * 24, 1)
219
+
220
+ def forward(self, x):
221
+ return self.model(x)
222
+
223
+ def training_step(self, batch, batch_idx):
224
+ breakpoint()
225
+ x = batch['video']
226
+ x = x.reshape(x.shape[0], -1)
227
+ y = self(x)
228
+ return y
229
+
230
+ def configure_optimizers(self):
231
+ return torch.optim.Adam(self.parameters(), lr=1e-3)
232
+
233
+ model = Model()
234
+
235
+ trainer = L.Trainer(accelerator="cpu", strategy=DDPStrategy(), max_steps=1000, devices=1)
236
+
237
+ dataloader = KubricData(
238
+ global_rank=trainer.global_rank,
239
+ data_dir='/media/data2/PointTracking/tensorflow_datasets',
240
+ batch_size=1 * trainer.world_size,
241
+ )
242
+
243
+ trainer.fit(model, dataloader)
locotrack_pytorch/environment.yml ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: locotrack-pytorch
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h5eee18b_6
8
+ - ca-certificates=2024.7.2=h06a4308_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.4.4=h6a678d5_1
11
+ - libgcc-ng=11.2.0=h1234567_1
12
+ - libgomp=11.2.0=h1234567_1
13
+ - libstdcxx-ng=11.2.0=h1234567_1
14
+ - libuuid=1.41.5=h5eee18b_0
15
+ - ncurses=6.4=h6a678d5_0
16
+ - openssl=3.0.14=h5eee18b_0
17
+ - pip=24.0=py311h06a4308_0
18
+ - python=3.11.9=h955ad1f_0
19
+ - readline=8.2=h5eee18b_0
20
+ - setuptools=69.5.1=py311h06a4308_0
21
+ - sqlite=3.45.3=h5eee18b_0
22
+ - tk=8.6.14=h39e8969_0
23
+ - tzdata=2024a=h04d1e81_0
24
+ - wheel=0.43.0=py311h06a4308_0
25
+ - xz=5.4.6=h5eee18b_1
26
+ - zlib=1.2.13=h5eee18b_1
27
+ - pip:
28
+ - absl-py==2.1.0
29
+ - aiohttp==3.9.5
30
+ - aiosignal==1.3.1
31
+ - array-record==0.5.1
32
+ - asttokens==2.4.1
33
+ - astunparse==1.6.3
34
+ - attrs==23.2.0
35
+ - certifi==2024.7.4
36
+ - charset-normalizer==3.3.2
37
+ - click==8.1.7
38
+ - contourpy==1.2.1
39
+ - cycler==0.12.1
40
+ - decorator==5.1.1
41
+ - dm-tree==0.1.8
42
+ - docker-pycreds==0.4.0
43
+ - docstring-parser==0.16
44
+ - einshape==1.0
45
+ - etils==1.9.2
46
+ - executing==2.0.1
47
+ - filelock==3.13.1
48
+ - flatbuffers==24.3.25
49
+ - fonttools==4.53.1
50
+ - frozenlist==1.4.1
51
+ - fsspec==2024.2.0
52
+ - gast==0.6.0
53
+ - gitdb==4.0.11
54
+ - gitpython==3.1.43
55
+ - google-pasta==0.2.0
56
+ - googleapis-common-protos==1.63.2
57
+ - grpcio==1.65.1
58
+ - h5py==3.11.0
59
+ - idna==3.7
60
+ - immutabledict==4.2.0
61
+ - importlib-resources==6.4.0
62
+ - ipython==8.26.0
63
+ - jedi==0.19.1
64
+ - jinja2==3.1.3
65
+ - keras==3.4.1
66
+ - kiwisolver==1.4.5
67
+ - libclang==18.1.1
68
+ - lightning==2.3.3
69
+ - lightning-utilities==0.11.6
70
+ - markdown==3.6
71
+ - markdown-it-py==3.0.0
72
+ - markupsafe==2.1.5
73
+ - matplotlib==3.9.1
74
+ - matplotlib-inline==0.1.7
75
+ - mdurl==0.1.2
76
+ - mediapy==1.2.2
77
+ - ml-dtypes==0.4.0
78
+ - mpmath==1.3.0
79
+ - multidict==6.0.5
80
+ - namex==0.0.8
81
+ - networkx==3.2.1
82
+ - numpy==1.26.3
83
+ - nvidia-cublas-cu12==12.4.2.65
84
+ - nvidia-cuda-cupti-cu12==12.4.99
85
+ - nvidia-cuda-nvrtc-cu12==12.4.99
86
+ - nvidia-cuda-runtime-cu12==12.4.99
87
+ - nvidia-cudnn-cu12==9.1.0.70
88
+ - nvidia-cufft-cu12==11.2.0.44
89
+ - nvidia-curand-cu12==10.3.5.119
90
+ - nvidia-cusolver-cu12==11.6.0.99
91
+ - nvidia-cusparse-cu12==12.3.0.142
92
+ - nvidia-nccl-cu12==2.20.5
93
+ - nvidia-nvjitlink-cu12==12.4.99
94
+ - nvidia-nvtx-cu12==12.4.99
95
+ - openexr==3.2.4
96
+ - opt-einsum==3.3.0
97
+ - optree==0.12.1
98
+ - packaging==24.1
99
+ - parso==0.8.4
100
+ - pexpect==4.9.0
101
+ - pillow==10.2.0
102
+ - platformdirs==4.2.2
103
+ - promise==2.3
104
+ - prompt-toolkit==3.0.47
105
+ - protobuf==4.25.4
106
+ - psutil==6.0.0
107
+ - ptyprocess==0.7.0
108
+ - pure-eval==0.2.3
109
+ - pyarrow==17.0.0
110
+ - pygments==2.18.0
111
+ - pyparsing==3.1.2
112
+ - python-dateutil==2.9.0.post0
113
+ - pytorch-lightning==2.3.3
114
+ - pyyaml==6.0.1
115
+ - requests==2.32.3
116
+ - rich==13.7.1
117
+ - scipy==1.14.0
118
+ - sentry-sdk==2.11.0
119
+ - setproctitle==1.3.3
120
+ - simple-parsing==0.1.5
121
+ - six==1.16.0
122
+ - smmap==5.0.1
123
+ - stack-data==0.6.3
124
+ - sympy==1.12
125
+ - tensorboard==2.17.0
126
+ - tensorboard-data-server==0.7.2
127
+ - tensorflow==2.17.0
128
+ - tensorflow-addons==0.23.0
129
+ - tensorflow-datasets==4.9.6
130
+ - tensorflow-graphics==2021.12.3
131
+ - tensorflow-io-gcs-filesystem==0.37.1
132
+ - tensorflow-metadata==1.15.0
133
+ - termcolor==2.4.0
134
+ - toml==0.10.2
135
+ - torch==2.4.0+cu124
136
+ - torchaudio==2.4.0+cu124
137
+ - torchmetrics==1.4.0.post0
138
+ - torchvision==0.19.0+cu124
139
+ - tqdm==4.66.4
140
+ - traitlets==5.14.3
141
+ - trimesh==4.4.3
142
+ - triton==3.0.0
143
+ - typeguard==2.13.3
144
+ - typing-extensions==4.9.0
145
+ - urllib3==2.2.2
146
+ - wandb==0.17.5
147
+ - wcwidth==0.2.13
148
+ - werkzeug==3.0.3
149
+ - wrapt==1.16.0
150
+ - yarl==1.9.4
151
+ - zipp==3.19.2
locotrack_pytorch/experiment.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import configparser
3
+ import argparse
4
+ import logging
5
+ from functools import partial
6
+ from typing import Any, Dict, Optional, Union
7
+
8
+ import lightning as L
9
+ from lightning.pytorch import seed_everything
10
+ from lightning.pytorch.loggers import WandbLogger
11
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+
15
+ from data.kubric_data import KubricData
16
+ from models.locotrack_model import LocoTrack
17
+ import model_utils
18
+ from data.evaluation_datasets import get_eval_dataset
19
+
20
+
21
+ class LocoTrackModel(L.LightningModule):
22
+ def __init__(
23
+ self,
24
+ model_kwargs: Optional[Dict[str, Any]] = None,
25
+ model_forward_kwargs: Optional[Dict[str, Any]] = None,
26
+ loss_name: Optional[str] = 'tapir_loss',
27
+ loss_kwargs: Optional[Dict[str, Any]] = None,
28
+ query_first: Optional[bool] = False,
29
+ optimizer_name: Optional[str] = 'Adam',
30
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
31
+ scheduler_name: Optional[str] = 'OneCycleLR',
32
+ scheduler_kwargs: Optional[Dict[str, Any]] = None,
33
+ ):
34
+ super().__init__()
35
+ self.model = LocoTrack(**(model_kwargs or {}))
36
+ self.model_forward_kwargs = model_forward_kwargs or {}
37
+ self.loss = partial(model_utils.__dict__[loss_name], **(loss_kwargs or {}))
38
+ self.query_first = query_first
39
+
40
+ self.optimizer_name = optimizer_name
41
+ self.optimizer_kwargs = optimizer_kwargs or {'lr': 2e-3}
42
+ self.scheduler_name = scheduler_name
43
+ self.scheduler_kwargs = scheduler_kwargs or {'max_lr': 2e-3, 'pct_start': 0.05, 'total_steps': 300000}
44
+
45
+ def training_step(self, batch, batch_idx):
46
+ output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
47
+ loss, loss_scalars = self.loss(batch, output)
48
+
49
+ self.log_dict(
50
+ {f'train/{k}': v.item() for k, v in loss_scalars.items()},
51
+ logger=True,
52
+ on_step=True,
53
+ sync_dist=True,
54
+ )
55
+
56
+ return loss
57
+
58
+ def validation_step(self, batch, batch_idx, dataloader_idx=None):
59
+ output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
60
+ loss, loss_scalars = self.loss(batch, output)
61
+ metrics = model_utils.eval_batch(batch, output, query_first=self.query_first)
62
+
63
+ if self.trainer.global_rank == 0:
64
+ log_prefix = 'val/'
65
+ if dataloader_idx is not None:
66
+ log_prefix = f'val/data_{dataloader_idx}/'
67
+
68
+ self.log_dict(
69
+ {log_prefix + k: v for k, v in loss_scalars.items()},
70
+ logger=True,
71
+ rank_zero_only=True,
72
+ )
73
+ self.log_dict(
74
+ {log_prefix + k: v.item() for k, v in metrics.items()},
75
+ logger=True,
76
+ rank_zero_only=True,
77
+ )
78
+ logging.info(f"Batch {batch_idx}: {metrics}")
79
+
80
+ def test_step(self, batch, batch_idx, dataloader_idx=None):
81
+ output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
82
+ loss, loss_scalars = self.loss(batch, output)
83
+ metrics = model_utils.eval_batch(batch, output, query_first=self.query_first)
84
+
85
+ if self.trainer.global_rank == 0:
86
+ log_prefix = 'test/'
87
+ if dataloader_idx is not None:
88
+ log_prefix = f'test/data_{dataloader_idx}/'
89
+
90
+ self.log_dict(
91
+ {log_prefix + k: v for k, v in loss_scalars.items()},
92
+ logger=True,
93
+ rank_zero_only=True,
94
+ )
95
+ self.log_dict(
96
+ {log_prefix + k: v.item() for k, v in metrics.items()},
97
+ logger=True,
98
+ rank_zero_only=True,
99
+ )
100
+ logging.info(f"Batch {batch_idx}: {metrics}")
101
+
102
+ def configure_optimizers(self):
103
+ weights = [p for n, p in self.named_parameters() if 'bias' not in n]
104
+ bias = [p for n, p in self.named_parameters() if 'bias' in n]
105
+
106
+ optimizer = torch.optim.__dict__[self.optimizer_name](
107
+ [
108
+ {'params': weights, **self.optimizer_kwargs},
109
+ {'params': bias, **self.optimizer_kwargs, 'weight_decay': 0.}
110
+ ]
111
+ )
112
+ scheduler = torch.optim.lr_scheduler.__dict__[self.scheduler_name](optimizer, **self.scheduler_kwargs)
113
+
114
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
115
+
116
+
117
+ def train(
118
+ mode: str,
119
+ save_path: str,
120
+ val_dataset_path: str,
121
+ ckpt_path: str = None,
122
+ kubric_dir: str = '',
123
+ precision: str = '32',
124
+ batch_size: int = 1,
125
+ val_check_interval: Union[int, float] = 5000,
126
+ log_every_n_steps: int = 10,
127
+ gradient_clip_val: float = 1.0,
128
+ max_steps: int = 300_000,
129
+ model_kwargs: Optional[Dict[str, Any]] = None,
130
+ model_forward_kwargs: Optional[Dict[str, Any]] = None,
131
+ loss_name: str = 'tapir_loss',
132
+ loss_kwargs: Optional[Dict[str, Any]] = None,
133
+ optimizer_name: str = 'Adam',
134
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
135
+ scheduler_name: str = 'OneCycleLR',
136
+ scheduler_kwargs: Optional[Dict[str, Any]] = None,
137
+ # query_first: bool = False,
138
+ ):
139
+ """Train the LocoTrack model with specified configurations."""
140
+ seed_everything(42, workers=True)
141
+
142
+ model = LocoTrackModel(
143
+ model_kwargs=model_kwargs,
144
+ model_forward_kwargs=model_forward_kwargs,
145
+ loss_name=loss_name,
146
+ loss_kwargs=loss_kwargs,
147
+ query_first='q_first' in mode,
148
+ optimizer_name=optimizer_name,
149
+ optimizer_kwargs=optimizer_kwargs,
150
+ scheduler_name=scheduler_name,
151
+ scheduler_kwargs=scheduler_kwargs,
152
+ )
153
+ if ckpt_path is not None and 'train' in mode:
154
+ model.load_state_dict(torch.load(ckpt_path)['state_dict'])
155
+
156
+ logger = WandbLogger(project='LocoTrack_Pytorch', save_dir=save_path, id=os.path.basename(save_path))
157
+ lr_monitor = LearningRateMonitor(logging_interval='step')
158
+ checkpoint_callback = ModelCheckpoint(
159
+ dirpath=save_path,
160
+ save_last=True,
161
+ save_top_k=3,
162
+ mode="max",
163
+ monitor="val/average_pts_within_thresh",
164
+ auto_insert_metric_name=True,
165
+ save_on_train_epoch_end=False,
166
+ )
167
+
168
+ eval_dataset = get_eval_dataset(
169
+ mode=mode,
170
+ path=val_dataset_path,
171
+ )
172
+ eval_dataloder = {
173
+ k: DataLoader(
174
+ v,
175
+ batch_size=1,
176
+ shuffle=False,
177
+ ) for k, v in eval_dataset.items()
178
+ }
179
+
180
+ if 'train' in mode:
181
+ trainer = L.Trainer(
182
+ strategy='ddp',
183
+ logger=logger,
184
+ precision=precision,
185
+ val_check_interval=val_check_interval,
186
+ log_every_n_steps=log_every_n_steps,
187
+ gradient_clip_val=gradient_clip_val,
188
+ max_steps=max_steps,
189
+ sync_batchnorm=True,
190
+ callbacks=[checkpoint_callback, lr_monitor],
191
+ )
192
+ train_dataloader = KubricData(
193
+ global_rank=trainer.global_rank,
194
+ data_dir=kubric_dir,
195
+ batch_size=batch_size * trainer.world_size,
196
+ )
197
+ trainer.fit(model, train_dataloader, eval_dataloder, ckpt_path=ckpt_path)
198
+ elif 'eval' in mode:
199
+ trainer = L.Trainer(strategy='ddp', logger=logger, precision=precision)
200
+ trainer.test(model, eval_dataloder, ckpt_path=ckpt_path)
201
+ else:
202
+ raise ValueError(f"Invalid mode: {mode}")
203
+
204
+ if __name__ == '__main__':
205
+ parser = argparse.ArgumentParser(description="Train or evaluate the LocoTrack model.")
206
+ parser.add_argument('--config', type=str, default='config.ini', help="Path to the configuration file.")
207
+ parser.add_argument('--mode', type=str, required=True, help="Mode to run: 'train' or 'eval' with optional 'q_first' and the name of evaluation dataset.")
208
+ parser.add_argument('--ckpt_path', type=str, default=None, help="Path to the checkpoint file")
209
+ parser.add_argument('--save_path', type=str, default='snapshots', help="Path to save the logs and checkpoints.")
210
+
211
+ args = parser.parse_args()
212
+ config = configparser.ConfigParser()
213
+ config.read(args.config)
214
+
215
+ # Extract parameters from the config file
216
+ train_params = {
217
+ 'mode': args.mode,
218
+ 'ckpt_path': args.ckpt_path,
219
+ 'save_path': args.save_path,
220
+ 'val_dataset_path': eval(config.get('TRAINING', 'val_dataset_path', fallback='{}')),
221
+ 'kubric_dir': config.get('TRAINING', 'kubric_dir', fallback=''),
222
+ 'precision': config.get('TRAINING', 'precision', fallback='32'),
223
+ 'batch_size': config.getint('TRAINING', 'batch_size', fallback=1),
224
+ 'val_check_interval': config.getfloat('TRAINING', 'val_check_interval', fallback=5000),
225
+ 'log_every_n_steps': config.getint('TRAINING', 'log_every_n_steps', fallback=10),
226
+ 'gradient_clip_val': config.getfloat('TRAINING', 'gradient_clip_val', fallback=1.0),
227
+ 'max_steps': config.getint('TRAINING', 'max_steps', fallback=300000),
228
+ 'model_kwargs': eval(config.get('MODEL', 'model_kwargs', fallback='{}')),
229
+ 'model_forward_kwargs': eval(config.get('MODEL', 'model_forward_kwargs', fallback='{}')),
230
+ 'loss_name': config.get('LOSS', 'loss_name', fallback='tapir_loss'),
231
+ 'loss_kwargs': eval(config.get('LOSS', 'loss_kwargs', fallback='{}')),
232
+ 'optimizer_name': config.get('OPTIMIZER', 'optimizer_name', fallback='Adam'),
233
+ 'optimizer_kwargs': eval(config.get('OPTIMIZER', 'optimizer_kwargs', fallback='{"lr": 2e-3}')),
234
+ 'scheduler_name': config.get('SCHEDULER', 'scheduler_name', fallback='OneCycleLR'),
235
+ 'scheduler_kwargs': eval(config.get('SCHEDULER', 'scheduler_kwargs', fallback='{"max_lr": 2e-3, "pct_start": 0.05, "total_steps": 300000}')),
236
+ }
237
+
238
+ train(**train_params)
locotrack_pytorch/model_utils.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from models.utils import convert_grid_coordinates
7
+ from data.evaluation_datasets import compute_tapvid_metrics
8
+
9
+ def huber_loss(tracks, target_points, occluded, delta=4.0, reduction_axes=(1, 2)):
10
+ """Huber loss for point trajectories."""
11
+ error = tracks - target_points
12
+ distsqr = torch.sum(error ** 2, dim=-1)
13
+ dist = torch.sqrt(distsqr + 1e-12) # add eps to prevent nan
14
+ loss_huber = torch.where(dist < delta, distsqr / 2, delta * (torch.abs(dist) - delta / 2))
15
+ loss_huber = loss_huber * (1.0 - occluded.float())
16
+
17
+ if reduction_axes:
18
+ loss_huber = torch.mean(loss_huber, dim=reduction_axes)
19
+
20
+ return loss_huber
21
+
22
+ def prob_loss(tracks, expd, target_points, occluded, expected_dist_thresh=8.0, reduction_axes=(1, 2)):
23
+ """Loss for classifying if a point is within pixel threshold of its target."""
24
+ err = torch.sum((tracks - target_points) ** 2, dim=-1)
25
+ invalid = (err > expected_dist_thresh ** 2).float()
26
+ logprob = F.binary_cross_entropy_with_logits(expd, invalid, reduction='none')
27
+ logprob = logprob * (1.0 - occluded.float())
28
+
29
+ if reduction_axes:
30
+ logprob = torch.mean(logprob, dim=reduction_axes)
31
+
32
+ return logprob
33
+
34
+ def tapnet_loss(points, occlusion, target_points, target_occ, shape, mask=None, expected_dist=None,
35
+ position_loss_weight=0.05, expected_dist_thresh=6.0, huber_loss_delta=4.0,
36
+ rebalance_factor=None, occlusion_loss_mask=None):
37
+ """TAPNet loss."""
38
+
39
+ if mask is None:
40
+ mask = torch.tensor(1.0)
41
+
42
+ points = convert_grid_coordinates(points, shape[3:1:-1], (256, 256), coordinate_format='xy')
43
+ target_points = convert_grid_coordinates(target_points, shape[3:1:-1], (256, 256), coordinate_format='xy')
44
+
45
+ loss_huber = huber_loss(points, target_points, target_occ, delta=huber_loss_delta, reduction_axes=None) * mask
46
+ loss_huber = torch.mean(loss_huber) * position_loss_weight
47
+
48
+ if expected_dist is None:
49
+ loss_prob = torch.tensor(0.0)
50
+ else:
51
+ loss_prob = prob_loss(points.detach(), expected_dist, target_points, target_occ, expected_dist_thresh, reduction_axes=None) * mask
52
+ loss_prob = torch.mean(loss_prob)
53
+
54
+ target_occ = target_occ.to(dtype=occlusion.dtype)
55
+ loss_occ = F.binary_cross_entropy_with_logits(occlusion, target_occ, reduction='none') * mask
56
+
57
+ if rebalance_factor is not None:
58
+ loss_occ = loss_occ * ((1 + rebalance_factor) - rebalance_factor * target_occ)
59
+
60
+ if occlusion_loss_mask is not None:
61
+ loss_occ = loss_occ * occlusion_loss_mask
62
+
63
+ loss_occ = torch.mean(loss_occ)
64
+
65
+ return loss_huber, loss_occ, loss_prob
66
+
67
+
68
+ def tapir_loss(
69
+ batch,
70
+ output,
71
+ position_loss_weight=0.05,
72
+ expected_dist_thresh=6.0,
73
+ ):
74
+ loss_scalars = {}
75
+ loss_huber, loss_occ, loss_prob = tapnet_loss(
76
+ output['tracks'],
77
+ output['occlusion'],
78
+ batch['target_points'],
79
+ batch['occluded'],
80
+ batch['video'].shape, # pytype: disable=attribute-error # numpy-scalars
81
+ expected_dist=output['expected_dist']
82
+ if 'expected_dist' in output
83
+ else None,
84
+ position_loss_weight=position_loss_weight,
85
+ expected_dist_thresh=expected_dist_thresh,
86
+ )
87
+ loss = loss_huber + loss_occ + loss_prob
88
+ loss_scalars['position_loss'] = loss_huber
89
+ loss_scalars['occlusion_loss'] = loss_occ
90
+ if 'expected_dist' in output:
91
+ loss_scalars['prob_loss'] = loss_prob
92
+
93
+ if 'unrefined_tracks' in output:
94
+ for l in range(len(output['unrefined_tracks'])):
95
+ loss_huber, loss_occ, loss_prob = tapnet_loss(
96
+ output['unrefined_tracks'][l],
97
+ output['unrefined_occlusion'][l],
98
+ batch['target_points'],
99
+ batch['occluded'],
100
+ batch['video'].shape, # pytype: disable=attribute-error # numpy-scalars
101
+ expected_dist=output['unrefined_expected_dist'][l]
102
+ if 'unrefined_expected_dist' in output
103
+ else None,
104
+ position_loss_weight=position_loss_weight,
105
+ expected_dist_thresh=expected_dist_thresh,
106
+ )
107
+ loss = loss + loss_huber + loss_occ + loss_prob
108
+ loss_scalars[f'position_loss_{l}'] = loss_huber
109
+ loss_scalars[f'occlusion_loss_{l}'] = loss_occ
110
+ if 'unrefined_expected_dist' in output:
111
+ loss_scalars[f'prob_loss_{l}'] = loss_prob
112
+
113
+ loss_scalars['loss'] = loss
114
+ return loss, loss_scalars
115
+
116
+
117
+
118
+ def eval_batch(
119
+ batch,
120
+ output,
121
+ eval_metrics_resolution = (256, 256),
122
+ query_first = False,
123
+ ):
124
+ query_points = batch['query_points']
125
+ query_points = convert_grid_coordinates(
126
+ query_points,
127
+ (1,) + batch['video'].shape[2:4], # (1, height, width)
128
+ (1,) + eval_metrics_resolution, # (1, height, width)
129
+ coordinate_format='tyx',
130
+ )
131
+ gt_target_points = batch['target_points']
132
+ gt_target_points = convert_grid_coordinates(
133
+ gt_target_points,
134
+ batch['video'].shape[3:1:-1], # (width, height)
135
+ eval_metrics_resolution[::-1], # (width, height)
136
+ coordinate_format='xy',
137
+ )
138
+ gt_occluded = batch['occluded']
139
+
140
+ tracks = output['tracks']
141
+ tracks = convert_grid_coordinates(
142
+ tracks,
143
+ batch['video'].shape[3:1:-1], # (width, height)
144
+ eval_metrics_resolution[::-1], # (width, height)
145
+ coordinate_format='xy',
146
+ )
147
+
148
+ occlusion_logits = output['occlusion']
149
+ pred_occ = torch.sigmoid(occlusion_logits)
150
+ if 'expected_dist' in output:
151
+ expected_dist = output['expected_dist']
152
+ pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(expected_dist))
153
+ pred_occ = pred_occ > 0.5 # threshold
154
+
155
+ query_mode = 'first' if query_first else 'strided'
156
+ metrics = compute_tapvid_metrics(
157
+ query_points=query_points.detach().cpu().numpy(),
158
+ gt_occluded=gt_occluded.detach().cpu().numpy(),
159
+ gt_tracks=gt_target_points.detach().cpu().numpy(),
160
+ pred_occluded=pred_occ.detach().cpu().numpy(),
161
+ pred_tracks=tracks.detach().cpu().numpy(),
162
+ query_mode=query_mode,
163
+ )
164
+
165
+ return metrics
locotrack_pytorch/models/cmdtop.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models import utils
5
+
6
+
7
+ class CMDTop(nn.Module):
8
+ def __init__(self, in_channel, out_channels, kernel_shapes, strides):
9
+ super(CMDTop, self).__init__()
10
+ self.in_channels = [in_channel] + list(out_channels[:-1])
11
+ self.out_channels = out_channels
12
+ self.kernel_shapes = kernel_shapes
13
+ self.strides = strides
14
+
15
+ self.conv = nn.ModuleList([
16
+ nn.Sequential(
17
+ utils.Conv2dSamePadding(
18
+ in_channels=self.in_channels[i],
19
+ out_channels=self.out_channels[i],
20
+ kernel_size=self.kernel_shapes[i],
21
+ stride=self.strides[i],
22
+ ),
23
+ nn.GroupNorm(out_channels[i] // 16, out_channels[i]),
24
+ nn.ReLU()
25
+ ) for i in range(len(out_channels))
26
+ ])
27
+
28
+ def forward(self, x):
29
+ """
30
+ x: (b, h, w, i, j)
31
+ """
32
+ out1 = utils.einshape('bhwij->b(ij)hw', x)
33
+ out2 = utils.einshape('bhwij->b(hw)ij', x)
34
+
35
+ for i in range(len(self.out_channels)):
36
+ out1 = self.conv[i](out1)
37
+
38
+ for i in range(len(self.out_channels)):
39
+ out2 = self.conv[i](out2)
40
+
41
+ out1 = torch.mean(out1, dim=(2, 3)) # (b, out_channels[-1])
42
+ out2 = torch.mean(out2, dim=(2, 3)) # (b, out_channels[-1])
43
+
44
+ return torch.cat([out1, out2], dim=-1) # (b, 2*out_channels[-1])
45
+
locotrack_pytorch/models/locotrack_model.py ADDED
@@ -0,0 +1,1053 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """TAPIR models definition."""
17
+
18
+ import functools
19
+ from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple
20
+
21
+ import torch
22
+ from torch import nn
23
+ import torch.nn.functional as F
24
+ import numpy as np
25
+
26
+ from models import nets, utils
27
+ from models.cmdtop import CMDTop
28
+
29
+
30
+ def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
31
+ """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
32
+
33
+ Instead of computing [sin(x), cos(x)], we use the trig identity
34
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
35
+
36
+ Args:
37
+ x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
38
+ min_deg: int, the minimum (inclusive) degree of the encoding.
39
+ max_deg: int, the maximum (exclusive) degree of the encoding.
40
+ legacy_posenc_order: bool, keep the same ordering as the original tf code.
41
+
42
+ Returns:
43
+ encoded: torch.Tensor, encoded variables.
44
+ """
45
+ if min_deg == max_deg:
46
+ return x
47
+ scales = torch.tensor([2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device)
48
+ if legacy_posenc_order:
49
+ xb = x[..., None, :] * scales[:, None]
50
+ four_feat = torch.reshape(
51
+ torch.sin(torch.stack([xb, xb + 0.5 * np.pi], dim=-2)),
52
+ list(x.shape[:-1]) + [-1]
53
+ )
54
+ else:
55
+ xb = torch.reshape((x[..., None, :] * scales[:, None]), list(x.shape[:-1]) + [-1])
56
+ four_feat = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1))
57
+ return torch.cat([x] + [four_feat], dim=-1)
58
+
59
+
60
+ def get_relative_positions(seq_len, reverse=False):
61
+ x = torch.arange(seq_len)[None, :]
62
+ y = torch.arange(seq_len)[:, None]
63
+ return torch.tril(x - y) if not reverse else torch.triu(y - x)
64
+
65
+
66
+ def get_alibi_slope(num_heads):
67
+ x = (24) ** (1 / num_heads)
68
+ return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], dtype=torch.float32).view(-1, 1, 1)
69
+
70
+
71
+ class MultiHeadAttention(nn.Module):
72
+ """Multi-headed attention (MHA) module."""
73
+
74
+ def __init__(self, num_heads, key_size, w_init_scale=None, w_init=None, with_bias=True, b_init=None, value_size=None, model_size=None):
75
+ super(MultiHeadAttention, self).__init__()
76
+ self.num_heads = num_heads
77
+ self.key_size = key_size
78
+ self.value_size = value_size or key_size
79
+ self.model_size = model_size or key_size * num_heads
80
+
81
+ self.with_bias = with_bias
82
+
83
+ self.query_proj = nn.Linear(num_heads * key_size, num_heads * key_size, bias=with_bias)
84
+ self.key_proj = nn.Linear(num_heads * key_size, num_heads * key_size, bias=with_bias)
85
+ self.value_proj = nn.Linear(num_heads * self.value_size, num_heads * self.value_size, bias=with_bias)
86
+ self.final_proj = nn.Linear(num_heads * self.value_size, self.model_size, bias=with_bias)
87
+
88
+ def forward(self, query, key, value, mask=None):
89
+ batch_size, sequence_length, _ = query.size()
90
+
91
+ query_heads = self._linear_projection(query, self.key_size, self.query_proj) # [T', H, Q=K]
92
+ key_heads = self._linear_projection(key, self.key_size, self.key_proj) # [T, H, K]
93
+ value_heads = self._linear_projection(value, self.value_size, self.value_proj) # [T, H, V]
94
+
95
+ bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length)
96
+ bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
97
+ bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length, reverse=True)
98
+ bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
99
+ attn_bias = torch.cat([bias_forward, bias_backward], dim=0).to(query.device)
100
+
101
+ attn_logits = torch.einsum("...thd,...Thd->...htT", query_heads, key_heads)
102
+ attn_logits = attn_logits / np.sqrt(self.key_size) + attn_bias
103
+
104
+ if mask is not None:
105
+ if mask.ndim != attn_logits.ndim:
106
+ raise ValueError(f"Mask dimensionality {mask.ndim} must match logits dimensionality {attn_logits.ndim}.")
107
+ attn_logits = torch.where(mask, attn_logits, torch.tensor(-1e30))
108
+
109
+ attn_weights = F.softmax(attn_logits, dim=-1) # [H, T', T]
110
+
111
+ attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
112
+ attn = attn.reshape(batch_size, sequence_length, -1) # [T', H*V]
113
+
114
+ return self.final_proj(attn) # [T', D']
115
+
116
+ def _linear_projection(self, x, head_size, proj_layer):
117
+ y = proj_layer(x)
118
+ *leading_dims, _ = x.shape
119
+ return y.reshape((*leading_dims, self.num_heads, head_size))
120
+
121
+
122
+ class Transformer(nn.Module):
123
+ """A transformer stack."""
124
+
125
+ def __init__(self, num_heads, num_layers, attn_size, dropout_rate, widening_factor=4):
126
+ super(Transformer, self).__init__()
127
+ self.num_heads = num_heads
128
+ self.num_layers = num_layers
129
+ self.attn_size = attn_size
130
+ self.dropout_rate = dropout_rate
131
+ self.widening_factor = widening_factor
132
+
133
+ self.layers = nn.ModuleList([
134
+ nn.ModuleDict({
135
+ 'attn': MultiHeadAttention(num_heads, attn_size, model_size=attn_size * num_heads),
136
+ 'dense': nn.Sequential(
137
+ nn.Linear(attn_size * num_heads, widening_factor * attn_size * num_heads),
138
+ nn.GELU(),
139
+ nn.Linear(widening_factor * attn_size * num_heads, attn_size * num_heads)
140
+ ),
141
+ 'layer_norm1': nn.LayerNorm(attn_size * num_heads),
142
+ 'layer_norm2': nn.LayerNorm(attn_size * num_heads)
143
+ })
144
+ for _ in range(num_layers)
145
+ ])
146
+
147
+ self.ln_out = nn.LayerNorm(attn_size * num_heads)
148
+
149
+ def forward(self, embeddings, mask=None):
150
+ h = embeddings
151
+ for layer in self.layers:
152
+ h_norm = layer['layer_norm1'](h)
153
+ h_attn = layer['attn'](h_norm, h_norm, h_norm, mask=mask)
154
+ h_attn = F.dropout(h_attn, p=self.dropout_rate, training=self.training)
155
+ h = h + h_attn
156
+
157
+ h_norm = layer['layer_norm2'](h)
158
+ h_dense = layer['dense'](h_norm)
159
+ h_dense = F.dropout(h_dense, p=self.dropout_rate, training=self.training)
160
+ h = h + h_dense
161
+
162
+ return self.ln_out(h)
163
+
164
+
165
+ class PIPSTransformer(nn.Module):
166
+ def __init__(self, input_channels, output_channels, dim=512, num_heads=8, num_layers=1):
167
+ super(PIPSTransformer, self).__init__()
168
+ self.dim = dim
169
+
170
+ self.transformer = Transformer(
171
+ num_heads=num_heads,
172
+ num_layers=num_layers,
173
+ attn_size=dim // num_heads,
174
+ dropout_rate=0.,
175
+ widening_factor=4,
176
+ )
177
+ self.input_proj = nn.Linear(input_channels, dim)
178
+ self.output_proj = nn.Linear(dim, output_channels)
179
+
180
+ def forward(self, x):
181
+ x = self.input_proj(x)
182
+ x = self.transformer(x, mask=None)
183
+ return self.output_proj(x)
184
+
185
+
186
+ class FeatureGrids(NamedTuple):
187
+ """Feature grids for a video, used to compute trajectories.
188
+
189
+ These are per-frame outputs of the encoding resnet.
190
+
191
+ Attributes:
192
+ lowres: Low-resolution features, one for each resolution; 256 channels.
193
+ hires: High-resolution features, one for each resolution; 64 channels.
194
+ resolutions: Resolutions used for trajectory computation. There will be one
195
+ entry for the initialization, and then an entry for each PIPs refinement
196
+ resolution.
197
+ """
198
+
199
+ lowres: Sequence[torch.Tensor]
200
+ hires: Sequence[torch.Tensor]
201
+ highest: Sequence[torch.Tensor]
202
+ resolutions: Sequence[Tuple[int, int]]
203
+
204
+
205
+ class QueryFeatures(NamedTuple):
206
+ """Query features used to compute trajectories.
207
+
208
+ These are sampled from the query frames and are a full descriptor of the
209
+ tracked points. They can be acquired from a query image and then reused in a
210
+ separate video.
211
+
212
+ Attributes:
213
+ lowres: Low-resolution features, one for each resolution; each has shape
214
+ [batch, num_query_points, 256]
215
+ hires: High-resolution features, one for each resolution; each has shape
216
+ [batch, num_query_points, 64]
217
+ resolutions: Resolutions used for trajectory computation. There will be one
218
+ entry for the initialization, and then an entry for each PIPs refinement
219
+ resolution.
220
+ """
221
+
222
+ lowres: Sequence[torch.Tensor]
223
+ hires: Sequence[torch.Tensor]
224
+ highest: Sequence[torch.Tensor]
225
+ lowres_supp: Sequence[torch.Tensor]
226
+ hires_supp: Sequence[torch.Tensor]
227
+ highest_supp: Sequence[torch.Tensor]
228
+ resolutions: Sequence[Tuple[int, int]]
229
+
230
+
231
+ class LocoTrack(nn.Module):
232
+ """TAPIR model."""
233
+
234
+ def __init__(
235
+ self,
236
+ bilinear_interp_with_depthwise_conv: bool = False,
237
+ num_pips_iter: int = 4,
238
+ pyramid_level: int = 0,
239
+ mixer_hidden_dim: int = 512,
240
+ num_mixer_blocks: int = 12,
241
+ mixer_kernel_shape: int = 3,
242
+ patch_size: int = 7,
243
+ softmax_temperature: float = 20.0,
244
+ parallelize_query_extraction: bool = False,
245
+ initial_resolution: Tuple[int, int] = (256, 256),
246
+ blocks_per_group: Sequence[int] = (2, 2, 2, 2),
247
+ feature_extractor_chunk_size: int = 256,
248
+ extra_convs: bool = False,
249
+ use_casual_conv: bool = False,
250
+ model_size: str = 'base',
251
+ ):
252
+ super().__init__()
253
+
254
+ if model_size == 'small':
255
+ model_params = {
256
+ 'dim': 256,
257
+ 'num_heads': 4,
258
+ 'num_layers': 3,
259
+ }
260
+ cmdtop_params = {
261
+ 'in_channel': 49,
262
+ 'out_channels': (64, 128),
263
+ 'kernel_shapes': (5, 2),
264
+ 'strides': (4, 2),
265
+ }
266
+ elif model_size == 'base':
267
+ model_params = {
268
+ 'dim': 384,
269
+ 'num_heads': 6,
270
+ 'num_layers': 3,
271
+ }
272
+ cmdtop_params = {
273
+ 'in_channel': 49,
274
+ 'out_channels': (64, 128, 128),
275
+ 'kernel_shapes': (3, 3, 2),
276
+ 'strides': (2, 2, 2),
277
+ }
278
+ else:
279
+ raise ValueError(f"Unknown model size '{model_size}'")
280
+
281
+ self.highres_dim = 128
282
+ self.lowres_dim = 256
283
+ self.bilinear_interp_with_depthwise_conv = (
284
+ bilinear_interp_with_depthwise_conv
285
+ )
286
+ self.parallelize_query_extraction = parallelize_query_extraction
287
+
288
+ self.num_pips_iter = num_pips_iter
289
+ self.pyramid_level = pyramid_level
290
+ self.patch_size = patch_size
291
+ self.softmax_temperature = softmax_temperature
292
+ self.initial_resolution = tuple(initial_resolution)
293
+ self.feature_extractor_chunk_size = feature_extractor_chunk_size
294
+ self.num_mixer_blocks = num_mixer_blocks
295
+ self.use_casual_conv = use_casual_conv
296
+
297
+ highres_dim = 128
298
+ lowres_dim = 256
299
+ strides = (1, 2, 2, 1)
300
+ blocks_per_group = (2, 2, 2, 2)
301
+ channels_per_group = (64, highres_dim, 256, lowres_dim)
302
+ use_projection = (True, True, True, True)
303
+
304
+ self.resnet_torch = nets.ResNet(
305
+ blocks_per_group=blocks_per_group,
306
+ channels_per_group=channels_per_group,
307
+ use_projection=use_projection,
308
+ strides=strides,
309
+ )
310
+
311
+ self.torch_pips_mixer = PIPSTransformer(
312
+ input_channels=854,
313
+ output_channels=4 + self.highres_dim + self.lowres_dim,
314
+ **model_params
315
+ )
316
+
317
+ self.cmdtop = nn.ModuleList([
318
+ CMDTop(
319
+ **cmdtop_params
320
+ ) for _ in range(3)
321
+ ])
322
+
323
+ self.cost_conv = utils.Conv2dSamePadding(2, 1, 3, 1)
324
+ self.occ_linear = nn.Linear(6, 2)
325
+
326
+ if extra_convs:
327
+ self.extra_convs = nets.ExtraConvs()
328
+ else:
329
+ self.extra_convs = None
330
+
331
+ def forward(
332
+ self,
333
+ video: torch.Tensor,
334
+ query_points: torch.Tensor,
335
+ feature_grids: Optional[FeatureGrids] = None,
336
+ is_training: bool = False,
337
+ query_chunk_size: Optional[int] = 64,
338
+ get_query_feats: bool = False,
339
+ refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
340
+ ) -> Mapping[str, torch.Tensor]:
341
+ """Runs a forward pass of the model.
342
+
343
+ Args:
344
+ video: A 5-D tensor representing a batch of sequences of images.
345
+ query_points: The query points for which we compute tracks.
346
+ is_training: Whether we are training.
347
+ query_chunk_size: When computing cost volumes, break the queries into
348
+ chunks of this size to save memory.
349
+ get_query_feats: Return query features for other losses like contrastive.
350
+ Not supported in the current version.
351
+ refinement_resolutions: A list of (height, width) tuples. Refinement will
352
+ be repeated at each specified resolution, in order to achieve high
353
+ accuracy on resolutions higher than what TAPIR was trained on. If None,
354
+ reasonable refinement resolutions will be inferred from the input video
355
+ size.
356
+
357
+ Returns:
358
+ A dict of outputs, including:
359
+ occlusion: Occlusion logits, of shape [batch, num_queries, num_frames]
360
+ where higher indicates more likely to be occluded.
361
+ tracks: predicted point locations, of shape
362
+ [batch, num_queries, num_frames, 2], where each point is [x, y]
363
+ in raster coordinates
364
+ expected_dist: uncertainty estimate logits, of shape
365
+ [batch, num_queries, num_frames], where higher indicates more likely
366
+ to be far from the correct answer.
367
+ """
368
+ if get_query_feats:
369
+ raise ValueError('Get query feats not supported in TAPIR.')
370
+
371
+ if feature_grids is None:
372
+ feature_grids = self.get_feature_grids(
373
+ video,
374
+ is_training,
375
+ refinement_resolutions,
376
+ )
377
+
378
+ query_features = self.get_query_features(
379
+ video,
380
+ is_training,
381
+ query_points,
382
+ feature_grids,
383
+ refinement_resolutions,
384
+ )
385
+
386
+ trajectories = self.estimate_trajectories(
387
+ video.shape[-3:-1],
388
+ is_training,
389
+ feature_grids,
390
+ query_features,
391
+ query_points,
392
+ query_chunk_size,
393
+ )
394
+
395
+ p = self.num_pips_iter
396
+ out = dict(
397
+ occlusion=torch.mean(
398
+ torch.stack(trajectories['occlusion'][p::p]), dim=0
399
+ ),
400
+ tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
401
+ expected_dist=torch.mean(
402
+ torch.stack(trajectories['expected_dist'][p::p]), dim=0
403
+ ),
404
+ unrefined_occlusion=trajectories['occlusion'][:-1],
405
+ unrefined_tracks=trajectories['tracks'][:-1],
406
+ unrefined_expected_dist=trajectories['expected_dist'][:-1],
407
+ )
408
+
409
+ return out
410
+
411
+ def get_query_features(
412
+ self,
413
+ video: torch.Tensor,
414
+ is_training: bool,
415
+ query_points: torch.Tensor,
416
+ feature_grids: Optional[FeatureGrids] = None,
417
+ refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
418
+ ) -> QueryFeatures:
419
+ """Computes query features, which can be used for estimate_trajectories.
420
+
421
+ Args:
422
+ video: A 5-D tensor representing a batch of sequences of images.
423
+ is_training: Whether we are training.
424
+ query_points: The query points for which we compute tracks.
425
+ feature_grids: If passed, we'll use these feature grids rather than
426
+ computing new ones.
427
+ refinement_resolutions: A list of (height, width) tuples. Refinement will
428
+ be repeated at each specified resolution, in order to achieve high
429
+ accuracy on resolutions higher than what TAPIR was trained on. If None,
430
+ reasonable refinement resolutions will be inferred from the input video
431
+ size.
432
+
433
+ Returns:
434
+ A QueryFeatures object which contains the required features for every
435
+ required resolution.
436
+ """
437
+
438
+ if feature_grids is None:
439
+ feature_grids = self.get_feature_grids(
440
+ video,
441
+ is_training=is_training,
442
+ refinement_resolutions=refinement_resolutions,
443
+ )
444
+
445
+ feature_grid = feature_grids.lowres
446
+ hires_feats = feature_grids.hires
447
+ highest_feats = feature_grids.highest
448
+ resize_im_shape = feature_grids.resolutions
449
+
450
+ shape = video.shape
451
+ # shape is [batch_size, time, height, width, channels]; conversion needs
452
+ # [time, width, height]
453
+ curr_resolution = (-1, -1)
454
+ query_feats = []
455
+ hires_query_feats = []
456
+ highest_query_feats = []
457
+ query_supp = []
458
+ hires_query_supp = []
459
+ highest_query_supp = []
460
+ for i, resolution in enumerate(resize_im_shape):
461
+ if utils.is_same_res(curr_resolution, resolution):
462
+ query_feats.append(query_feats[-1])
463
+ hires_query_feats.append(hires_query_feats[-1])
464
+ highest_query_feats.append(highest_query_feats[-1])
465
+ query_supp.append(query_supp[-1])
466
+ hires_query_supp.append(hires_query_supp[-1])
467
+ highest_query_supp.append(highest_query_supp[-1])
468
+ continue
469
+ position_in_grid = utils.convert_grid_coordinates(
470
+ query_points,
471
+ shape[1:4],
472
+ feature_grid[i].shape[1:4],
473
+ coordinate_format='tyx',
474
+ )
475
+ position_in_grid_hires = utils.convert_grid_coordinates(
476
+ query_points,
477
+ shape[1:4],
478
+ hires_feats[i].shape[1:4],
479
+ coordinate_format='tyx',
480
+ )
481
+ position_in_grid_highest = utils.convert_grid_coordinates(
482
+ query_points,
483
+ shape[1:4],
484
+ highest_feats[i].shape[1:4],
485
+ coordinate_format='tyx',
486
+ )
487
+
488
+ support_size = 7
489
+ ctxx, ctxy = torch.meshgrid(
490
+ torch.arange(-(support_size // 2), support_size // 2 + 1),
491
+ torch.arange(-(support_size // 2), support_size // 2 + 1),
492
+ indexing='xy',
493
+ )
494
+ ctx = torch.stack([torch.zeros_like(ctxy), ctxy, ctxx], axis=-1)
495
+ ctx = torch.reshape(ctx, [-1, 3]).to(video.device) # s*s 3
496
+
497
+ position_support = position_in_grid[..., None, :] + ctx[None, None, ...] # b n s*s 3
498
+ position_support = utils.einshape('bnsc->b(ns)c', position_support)
499
+ interp_supp = utils.map_coordinates_3d(
500
+ feature_grid[i], position_support
501
+ )
502
+ interp_supp = utils.einshape('b(nhw)c->bnhwc', interp_supp, h=support_size, w=support_size)
503
+
504
+ position_support_hires = position_in_grid_hires[..., None, :] + ctx[None, None, ...]
505
+ position_support_hires = utils.einshape('bnsc->b(ns)c', position_support_hires)
506
+ hires_interp_supp = utils.map_coordinates_3d(
507
+ hires_feats[i], position_support_hires
508
+ )
509
+ hires_interp_supp = utils.einshape('b(nhw)c->bnhwc', hires_interp_supp, h=support_size, w=support_size)
510
+
511
+ position_support_highest = position_in_grid_highest[..., None, :] + ctx[None, None, ...]
512
+ position_support_highest = utils.einshape('bnsc->b(ns)c', position_support_highest)
513
+ highest_interp_supp = utils.map_coordinates_3d(
514
+ highest_feats[i], position_support_highest
515
+ )
516
+ highest_interp_supp = utils.einshape('b(nhw)c->bnhwc', highest_interp_supp, h=support_size, w=support_size)
517
+
518
+ interp_features = interp_supp[..., support_size // 2, support_size // 2, :]
519
+ hires_interp = hires_interp_supp[..., support_size // 2, support_size // 2, :]
520
+ highest_interp = highest_interp_supp[..., support_size // 2, support_size // 2, :]
521
+
522
+ hires_query_feats.append(hires_interp)
523
+ query_feats.append(interp_features)
524
+ highest_query_feats.append(highest_interp)
525
+ query_supp.append(interp_supp)
526
+ hires_query_supp.append(hires_interp_supp)
527
+ highest_query_supp.append(highest_interp_supp)
528
+
529
+ return QueryFeatures(
530
+ tuple(query_feats), tuple(hires_query_feats), tuple(highest_query_feats),
531
+ tuple(query_supp), tuple(hires_query_supp), tuple(highest_query_supp), tuple(resize_im_shape),
532
+ )
533
+
534
+ def get_feature_grids(
535
+ self,
536
+ video: torch.Tensor,
537
+ is_training: Optional[bool] = False,
538
+ refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
539
+ ) -> FeatureGrids:
540
+ """Computes feature grids.
541
+
542
+ Args:
543
+ video: A 5-D tensor representing a batch of sequences of images.
544
+ is_training: Whether we are training.
545
+ refinement_resolutions: A list of (height, width) tuples. Refinement will
546
+ be repeated at each specified resolution, to achieve high accuracy on
547
+ resolutions higher than what TAPIR was trained on. If None, reasonable
548
+ refinement resolutions will be inferred from the input video size.
549
+
550
+ Returns:
551
+ A FeatureGrids object containing the required features for every
552
+ required resolution. Note that there will be one more feature grid
553
+ than there are refinement_resolutions, because there is always a
554
+ feature grid computed for TAP-Net initialization.
555
+ """
556
+ del is_training
557
+ if refinement_resolutions is None:
558
+ refinement_resolutions = utils.generate_default_resolutions(
559
+ video.shape[2:4], self.initial_resolution
560
+ )
561
+
562
+ all_required_resolutions = [self.initial_resolution]
563
+ all_required_resolutions.extend(refinement_resolutions)
564
+
565
+ feature_grid = []
566
+ hires_feats = []
567
+ highest_feats = []
568
+ resize_im_shape = []
569
+ curr_resolution = (-1, -1)
570
+
571
+ latent = None
572
+ hires = None
573
+ video_resize = None
574
+ for resolution in all_required_resolutions:
575
+ if resolution[0] % 8 != 0 or resolution[1] % 8 != 0:
576
+ raise ValueError('Image resolution must be a multiple of 8.')
577
+
578
+ if not utils.is_same_res(curr_resolution, resolution):
579
+ if utils.is_same_res(curr_resolution, video.shape[-3:-1]):
580
+ video_resize = video
581
+ else:
582
+ video_resize = utils.bilinear(video, resolution)
583
+
584
+ curr_resolution = resolution
585
+ n, f, h, w, c = video_resize.shape
586
+ video_resize = video_resize.view(n*f, h, w, c).permute(0, 3, 1, 2)
587
+
588
+ if self.feature_extractor_chunk_size > 0:
589
+ latent_list = []
590
+ hires_list = []
591
+ highest_list = []
592
+ chunk_size = self.feature_extractor_chunk_size
593
+ for start_idx in range(0, video_resize.shape[0], chunk_size):
594
+ video_chunk = video_resize[start_idx:start_idx + chunk_size]
595
+ resnet_out = self.resnet_torch(video_chunk)
596
+
597
+ u3 = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1)
598
+ latent_list.append(u3)
599
+ u1 = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1)
600
+ hires_list.append(u1)
601
+ u0 = resnet_out['resnet_unit_0'].permute(0, 2, 3, 1)
602
+ highest_list.append(u0)
603
+
604
+ latent = torch.cat(latent_list, dim=0)
605
+ hires = torch.cat(hires_list, dim=0)
606
+ highest = torch.cat(highest_list, dim=0)
607
+
608
+ else:
609
+ resnet_out = self.resnet_torch(video_resize)
610
+ latent = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1)
611
+ hires = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1)
612
+ highest = resnet_out['resnet_unit_0'].permute(0, 2, 3, 1)
613
+
614
+ if self.extra_convs:
615
+ latent = self.extra_convs(latent)
616
+
617
+ latent = latent / torch.sqrt(
618
+ torch.maximum(
619
+ torch.sum(torch.square(latent), axis=-1, keepdims=True),
620
+ torch.tensor(1e-12, device=latent.device),
621
+ )
622
+ )
623
+ hires = hires / torch.sqrt(
624
+ torch.maximum(
625
+ torch.sum(torch.square(hires), axis=-1, keepdims=True),
626
+ torch.tensor(1e-12, device=hires.device),
627
+ )
628
+ )
629
+ highest = highest / torch.sqrt(
630
+ torch.maximum(
631
+ torch.sum(torch.square(highest), axis=-1, keepdims=True),
632
+ torch.tensor(1e-12, device=highest.device),
633
+ )
634
+ )
635
+
636
+ latent = latent.view(n, f, *latent.shape[1:])
637
+ hires = hires.view(n, f, *hires.shape[1:])
638
+ highest = highest.view(n, f, *highest.shape[1:])
639
+
640
+ feature_grid.append(latent)
641
+ hires_feats.append(hires)
642
+ highest_feats.append(highest)
643
+ resize_im_shape.append(video_resize.shape[2:4])
644
+
645
+ return FeatureGrids(
646
+ tuple(feature_grid), tuple(hires_feats), tuple(highest_feats), tuple(resize_im_shape)
647
+ )
648
+
649
+ def estimate_trajectories(
650
+ self,
651
+ video_size: Tuple[int, int],
652
+ is_training: bool,
653
+ feature_grids: FeatureGrids,
654
+ query_features: QueryFeatures,
655
+ query_points_in_video: Optional[torch.Tensor],
656
+ query_chunk_size: Optional[int] = None,
657
+ causal_context: Optional[dict[str, torch.Tensor]] = None,
658
+ get_causal_context: bool = False,
659
+ ) -> Mapping[str, Any]:
660
+ """Estimates trajectories given features for a video and query features.
661
+
662
+ Args:
663
+ video_size: A 2-tuple containing the original [height, width] of the
664
+ video. Predictions will be scaled with respect to this resolution.
665
+ is_training: Whether we are training.
666
+ feature_grids: a FeatureGrids object computed for the given video.
667
+ query_features: a QueryFeatures object computed for the query points.
668
+ query_points_in_video: If provided, assume that the query points come from
669
+ the same video as feature_grids, and therefore constrain the resulting
670
+ trajectories to (approximately) pass through them.
671
+ query_chunk_size: When computing cost volumes, break the queries into
672
+ chunks of this size to save memory.
673
+ causal_context: If provided, a dict of causal context to use for
674
+ refinement.
675
+ get_causal_context: If True, return causal context in the output.
676
+
677
+ Returns:
678
+ A dict of outputs, including:
679
+ occlusion: Occlusion logits, of shape [batch, num_queries, num_frames]
680
+ where higher indicates more likely to be occluded.
681
+ tracks: predicted point locations, of shape
682
+ [batch, num_queries, num_frames, 2], where each point is [x, y]
683
+ in raster coordinates
684
+ expected_dist: uncertainty estimate logits, of shape
685
+ [batch, num_queries, num_frames], where higher indicates more likely
686
+ to be far from the correct answer.
687
+ """
688
+ del is_training
689
+
690
+ def train2orig(x):
691
+ return utils.convert_grid_coordinates(
692
+ x,
693
+ self.initial_resolution[::-1],
694
+ video_size[::-1],
695
+ coordinate_format='xy',
696
+ )
697
+
698
+ occ_iters = []
699
+ pts_iters = []
700
+ expd_iters = []
701
+ new_causal_context = []
702
+ num_iters = self.num_pips_iter * (len(feature_grids.lowres) - 1)
703
+ for _ in range(num_iters + 1):
704
+ occ_iters.append([])
705
+ pts_iters.append([])
706
+ expd_iters.append([])
707
+ new_causal_context.append([])
708
+ del new_causal_context[-1]
709
+
710
+ infer = functools.partial(
711
+ self.tracks_from_cost_volume,
712
+ im_shp=feature_grids.lowres[0].shape[0:2]
713
+ + self.initial_resolution
714
+ + (3,),
715
+ )
716
+
717
+ num_queries = query_features.lowres[0].shape[1]
718
+ if causal_context is None:
719
+ perm = torch.randperm(num_queries)
720
+ else:
721
+ perm = torch.arange(num_queries)
722
+
723
+ inv_perm = torch.zeros_like(perm)
724
+ inv_perm[perm] = torch.arange(num_queries)
725
+
726
+ for ch in range(0, num_queries, query_chunk_size):
727
+ perm_chunk = perm[ch : ch + query_chunk_size]
728
+ chunk = query_features.lowres[0][:, perm_chunk]
729
+ chunk_hires = query_features.hires[0][:, perm_chunk]
730
+
731
+ cc_chunk = []
732
+ if causal_context is not None:
733
+ for d in range(len(causal_context)):
734
+ tmp_dict = {}
735
+ for k, v in causal_context[d].items():
736
+ tmp_dict[k] = v[:, perm_chunk]
737
+ cc_chunk.append(tmp_dict)
738
+
739
+ if query_points_in_video is not None:
740
+ infer_query_points = query_points_in_video[
741
+ :, perm[ch : ch + query_chunk_size]
742
+ ]
743
+ num_frames = feature_grids.lowres[0].shape[1]
744
+ infer_query_points = utils.convert_grid_coordinates(
745
+ infer_query_points,
746
+ (num_frames,) + video_size,
747
+ (num_frames,) + self.initial_resolution,
748
+ coordinate_format='tyx',
749
+ )
750
+ else:
751
+ infer_query_points = None
752
+
753
+ points, occlusion, expected_dist, cost_volume = infer(
754
+ chunk,
755
+ chunk_hires,
756
+ feature_grids.lowres[0],
757
+ feature_grids.hires[0],
758
+ infer_query_points,
759
+ )
760
+ pts_iters[0].append(train2orig(points))
761
+ occ_iters[0].append(occlusion)
762
+ expd_iters[0].append(expected_dist)
763
+
764
+ mixer_feats = None
765
+ for i in range(num_iters):
766
+ feature_level = i // self.num_pips_iter + 1
767
+ queries = [
768
+ query_features.hires[feature_level][:, perm_chunk],
769
+ query_features.lowres[feature_level][:, perm_chunk],
770
+ query_features.highest[feature_level][:, perm_chunk],
771
+ ]
772
+ supports = [
773
+ query_features.hires_supp[feature_level][:, perm_chunk],
774
+ query_features.lowres_supp[feature_level][:, perm_chunk],
775
+ query_features.highest_supp[feature_level][:, perm_chunk],
776
+ ]
777
+ for _ in range(self.pyramid_level):
778
+ queries.append(queries[-1])
779
+ pyramid = [
780
+ feature_grids.hires[feature_level],
781
+ feature_grids.lowres[feature_level],
782
+ feature_grids.highest[feature_level],
783
+ ]
784
+ for _ in range(self.pyramid_level):
785
+ pyramid.append(
786
+ F.avg_pool3d(
787
+ pyramid[-1],
788
+ kernel_size=(2, 2, 1),
789
+ stride=(2, 2, 1),
790
+ padding=0,
791
+ )
792
+ )
793
+ cc = cc_chunk[i] if causal_context is not None else None
794
+ refined = self.refine_pips(
795
+ queries,
796
+ supports,
797
+ None,
798
+ pyramid,
799
+ points.detach(),
800
+ occlusion.detach(),
801
+ expected_dist.detach(),
802
+ orig_hw=self.initial_resolution,
803
+ last_iter=mixer_feats,
804
+ mixer_iter=i,
805
+ resize_hw=feature_grids.resolutions[feature_level],
806
+ causal_context=cc,
807
+ get_causal_context=get_causal_context,
808
+ cost_volume=cost_volume
809
+ )
810
+ points, occlusion, expected_dist, mixer_feats, new_causal = refined
811
+ pts_iters[i + 1].append(train2orig(points))
812
+ occ_iters[i + 1].append(occlusion)
813
+ expd_iters[i + 1].append(expected_dist)
814
+ new_causal_context[i].append(new_causal)
815
+
816
+ if (i + 1) % self.num_pips_iter == 0:
817
+ mixer_feats = None
818
+ expected_dist = expd_iters[0][-1]
819
+ occlusion = occ_iters[0][-1]
820
+
821
+ occlusion = []
822
+ points = []
823
+ expd = []
824
+ for i, _ in enumerate(occ_iters):
825
+ occlusion.append(torch.cat(occ_iters[i], dim=1)[:, inv_perm])
826
+ points.append(torch.cat(pts_iters[i], dim=1)[:, inv_perm])
827
+ expd.append(torch.cat(expd_iters[i], dim=1)[:, inv_perm])
828
+
829
+ out = dict(
830
+ occlusion=occlusion,
831
+ tracks=points,
832
+ expected_dist=expd,
833
+ )
834
+ return out
835
+
836
+ def refine_pips(
837
+ self,
838
+ target_feature,
839
+ support_feature,
840
+ frame_features,
841
+ pyramid,
842
+ pos_guess,
843
+ occ_guess,
844
+ expd_guess,
845
+ orig_hw,
846
+ last_iter=None,
847
+ mixer_iter=0.0,
848
+ resize_hw=None,
849
+ causal_context=None,
850
+ get_causal_context=False,
851
+ cost_volume=None,
852
+ ):
853
+ del frame_features
854
+ del mixer_iter
855
+ orig_h, orig_w = orig_hw
856
+ resized_h, resized_w = resize_hw
857
+ corrs_pyr = []
858
+ assert len(target_feature) == len(pyramid)
859
+ for pyridx, (query, supp, grid) in enumerate(zip(target_feature, support_feature, pyramid)):
860
+ # note: interp needs [y,x]
861
+ coords = utils.convert_grid_coordinates(
862
+ pos_guess, (orig_w, orig_h), grid.shape[-2:-4:-1]
863
+ )
864
+ coords = torch.flip(coords, dims=(-1,))
865
+
866
+ support_size = 7
867
+ ctxx, ctxy = torch.meshgrid(
868
+ torch.arange(-(support_size // 2), support_size // 2 + 1),
869
+ torch.arange(-(support_size // 2), support_size // 2 + 1),
870
+ indexing='xy',
871
+ )
872
+ ctx = torch.stack([ctxy, ctxx], dim=-1)
873
+ ctx = ctx.reshape(-1, 2).to(coords.device)
874
+ coords2 = coords.unsqueeze(3) + ctx.unsqueeze(0).unsqueeze(0).unsqueeze(0)
875
+ neighborhood = utils.map_coordinates_2d(grid, coords2)
876
+
877
+ neighborhood = utils.einshape('bnt(hw)c->bnthwc', neighborhood, h=support_size, w=support_size)
878
+ patches_input = torch.einsum('bnthwc,bnijc->bnthwij', neighborhood, supp)
879
+ patches_input = utils.einshape('bnthwij->(bnt)hwij', patches_input)
880
+ patches_emb = self.cmdtop[pyridx](patches_input)
881
+ patches = utils.einshape('(bnt)c->bntc', patches_emb, b=neighborhood.shape[0], n=neighborhood.shape[1])
882
+
883
+ corrs_pyr.append(patches)
884
+ corrs_pyr = torch.concatenate(corrs_pyr, dim=-1)
885
+
886
+ corrs_chunked = corrs_pyr
887
+ pos_guess_input = pos_guess
888
+ occ_guess_input = occ_guess[..., None]
889
+ expd_guess_input = expd_guess[..., None]
890
+
891
+ # mlp_input is batch, num_points, num_chunks, frames_per_chunk, channels
892
+ if last_iter is None:
893
+ both_feature = torch.cat([target_feature[0], target_feature[1]], axis=-1)
894
+ mlp_input_features = torch.tile(
895
+ both_feature.unsqueeze(2), (1, 1, corrs_chunked.shape[-2], 1)
896
+ )
897
+ else:
898
+ mlp_input_features = last_iter
899
+
900
+ mlp_input_list = [
901
+ occ_guess_input,
902
+ expd_guess_input,
903
+ corrs_chunked
904
+ ]
905
+
906
+ rel_pos_forward = F.pad(pos_guess_input[..., :-1, :] - pos_guess_input[..., 1:, :], (0, 0, 0, 1))
907
+ rel_pos_backward = F.pad(pos_guess_input[..., 1:, :] - pos_guess_input[..., :-1, :], (0, 0, 1, 0))
908
+ scale = torch.tensor([resized_w / orig_w, resized_h / orig_h]) / torch.tensor([orig_w, orig_h])
909
+ scale = scale.to(pos_guess_input.device)
910
+ rel_pos_forward = rel_pos_forward * scale
911
+ rel_pos_backward = rel_pos_backward * scale
912
+ rel_pos_emb_input = posenc(torch.cat([rel_pos_forward, rel_pos_backward], axis=-1), min_deg=0, max_deg=10) # batch, num_points, num_frames, 84
913
+ mlp_input_list.append(rel_pos_emb_input)
914
+ mlp_input = torch.cat(mlp_input_list, axis=-1)
915
+
916
+ x = utils.einshape('bnfc->(bn)fc', mlp_input)
917
+
918
+ if causal_context is not None:
919
+ for k, v in causal_context.items():
920
+ causal_context[k] = utils.einshape('bn...->(bn)...', v)
921
+ res = self.torch_pips_mixer(x)
922
+
923
+ res = utils.einshape('(bn)fc->bnfc', res, b=mlp_input.shape[0])
924
+
925
+ pos_update = utils.convert_grid_coordinates(
926
+ res[..., :2],
927
+ (resized_w, resized_h),
928
+ (orig_w, orig_h),
929
+ )
930
+ return (
931
+ pos_update + pos_guess,
932
+ res[..., 2] + occ_guess,
933
+ res[..., 3] + expd_guess,
934
+ res[..., 4:] + (mlp_input_features if last_iter is None else last_iter),
935
+ None,
936
+ )
937
+
938
+ def tracks_from_cost_volume(
939
+ self,
940
+ interp_feature: torch.Tensor,
941
+ interp_feature_hires: torch.Tensor,
942
+ feature_grid: torch.Tensor,
943
+ feature_grid_hires: torch.Tensor,
944
+ query_points: Optional[torch.Tensor],
945
+ im_shp=None,
946
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
947
+ """Converts features into tracks by computing a cost volume.
948
+
949
+ The computed cost volume will have shape
950
+ [batch, num_queries, time, height, width], which can be very
951
+ memory intensive.
952
+
953
+ Args:
954
+ interp_feature: A tensor of features for each query point, of shape
955
+ [batch, num_queries, channels, heads].
956
+ feature_grid: A tensor of features for the video, of shape [batch, time,
957
+ height, width, channels, heads].
958
+ query_points: When computing tracks, we assume these points are given as
959
+ ground truth and we reproduce them exactly. This is a set of points of
960
+ shape [batch, num_points, 3], where each entry is [t, y, x] in frame/
961
+ raster coordinates.
962
+ im_shp: The shape of the original image, i.e., [batch, num_frames, time,
963
+ height, width, 3].
964
+
965
+ Returns:
966
+ A 2-tuple of the inferred points (of shape
967
+ [batch, num_points, num_frames, 2] where each point is [x, y]) and
968
+ inferred occlusion (of shape [batch, num_points, num_frames], where
969
+ each is a logit where higher means occluded)
970
+ """
971
+
972
+ cost_volume = torch.einsum(
973
+ 'bnc,bthwc->tbnhw',
974
+ interp_feature,
975
+ feature_grid,
976
+ )
977
+ cost_volume_hires = torch.einsum(
978
+ 'bnc,bthwc->tbnhw',
979
+ interp_feature_hires,
980
+ feature_grid_hires,
981
+ )
982
+
983
+ shape = cost_volume.shape
984
+ batch_size, num_points = cost_volume.shape[1:3]
985
+
986
+ interp_cost = utils.einshape('tbnhw->(tbn)1hw', cost_volume)
987
+ interp_cost = F.interpolate(interp_cost, cost_volume_hires.shape[3:], mode='bilinear', align_corners=False)
988
+ # TODO: not sure if this is correct
989
+ interp_cost = utils.einshape('(tbn)1hw->tbnhw', interp_cost, b=batch_size, n=num_points)
990
+ cost_volume_stack = torch.stack(
991
+ [
992
+ # jax.image.resize(cost_volume, cost_volume_hires.shape, method='bilinear'),
993
+ interp_cost,
994
+ cost_volume_hires,
995
+ ], dim=-1
996
+ )
997
+ pos = utils.einshape('tbnhwc->(tbn)chw', cost_volume_stack)
998
+ pos = self.cost_conv(pos)
999
+ pos = utils.einshape('(tbn)1hw->bnthw', pos, b=batch_size, n=num_points)
1000
+
1001
+ pos_sm = pos.reshape(pos.size(0), pos.size(1), pos.size(2), -1)
1002
+ softmaxed = F.softmax(pos_sm * self.softmax_temperature, dim=-1)
1003
+ pos = softmaxed.view_as(pos)
1004
+
1005
+ points = utils.heatmaps_to_points(pos, im_shp, query_points=query_points)
1006
+
1007
+ occlusion = torch.cat(
1008
+ [
1009
+ torch.mean(cost_volume_stack, dim=(-2, -3)),
1010
+ torch.amax(cost_volume_stack, dim=(-2, -3)),
1011
+ torch.amin(cost_volume_stack, dim=(-2, -3)),
1012
+ ], dim=-1
1013
+ )
1014
+ occlusion = self.occ_linear(occlusion)
1015
+ expected_dist = utils.einshape(
1016
+ 'tbn1->bnt', occlusion[..., 1:2]
1017
+ )
1018
+ occlusion = utils.einshape(
1019
+ 'tbn1->bnt', occlusion[..., 0:1]
1020
+ )
1021
+
1022
+ return points, occlusion, expected_dist, utils.einshape('tbnhw->bnthw', cost_volume)
1023
+
1024
+ def construct_initial_causal_state(self, num_points, num_resolutions=1):
1025
+ """Construct initial causal state."""
1026
+ value_shapes = {}
1027
+ for i in range(self.num_mixer_blocks):
1028
+ value_shapes[f'block_{i}_causal_1'] = (1, num_points, 2, 512)
1029
+ value_shapes[f'block_{i}_causal_2'] = (1, num_points, 2, 2048)
1030
+ fake_ret = {
1031
+ k: torch.zeros(v, dtype=torch.float32) for k, v in value_shapes.items()
1032
+ }
1033
+ return [fake_ret] * num_resolutions * 4
1034
+
1035
+
1036
+ CHECKPOINT_LINK = {
1037
+ 'small': 'https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_small.ckpt',
1038
+ 'base': 'https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_base.ckpt',
1039
+ }
1040
+
1041
+ def load_model(ckpt_path=None, model_size='base'):
1042
+ if ckpt_path is None:
1043
+ ckpt_link = CHECKPOINT_LINK[model_size]
1044
+ state_dict = torch.hub.load_state_dict_from_url(ckpt_link, map_location='cpu')['state_dict']
1045
+ else:
1046
+ state_dict = torch.load(ckpt_path)['state_dict']
1047
+ state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
1048
+
1049
+ model = LocoTrack(model_size=model_size)
1050
+ model.load_state_dict(state_dict)
1051
+ model.eval()
1052
+
1053
+ return model
locotrack_pytorch/models/nets.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Pytorch neural network definitions."""
17
+
18
+ from typing import Sequence, Union
19
+
20
+ import torch
21
+ from torch import nn
22
+ import torch.nn.functional as F
23
+
24
+ from models.utils import Conv2dSamePadding
25
+
26
+
27
+ class ExtraConvBlock(nn.Module):
28
+ """Additional convolution block."""
29
+
30
+ def __init__(
31
+ self,
32
+ channel_dim,
33
+ channel_multiplier,
34
+ ):
35
+ super().__init__()
36
+ self.channel_dim = channel_dim
37
+ self.channel_multiplier = channel_multiplier
38
+
39
+ self.layer_norm = nn.LayerNorm(
40
+ normalized_shape=channel_dim, elementwise_affine=True, bias=True
41
+ )
42
+ self.conv = nn.Conv2d(
43
+ self.channel_dim,
44
+ self.channel_dim * self.channel_multiplier,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1,
48
+ )
49
+ self.conv_1 = nn.Conv2d(
50
+ self.channel_dim * self.channel_multiplier,
51
+ self.channel_dim,
52
+ kernel_size=3,
53
+ stride=1,
54
+ padding=1,
55
+ )
56
+
57
+ def forward(self, x):
58
+ x = self.layer_norm(x)
59
+ x = x.permute(0, 3, 1, 2)
60
+ res = self.conv(x)
61
+ res = F.gelu(res, approximate='tanh')
62
+ x = x + self.conv_1(res)
63
+ x = x.permute(0, 2, 3, 1)
64
+ return x
65
+
66
+
67
+ class ExtraConvs(nn.Module):
68
+ """Additional CNN."""
69
+
70
+ def __init__(
71
+ self,
72
+ num_layers=5,
73
+ channel_dim=256,
74
+ channel_multiplier=4,
75
+ ):
76
+ super().__init__()
77
+ self.num_layers = num_layers
78
+ self.channel_dim = channel_dim
79
+ self.channel_multiplier = channel_multiplier
80
+
81
+ self.blocks = nn.ModuleList()
82
+ for _ in range(self.num_layers):
83
+ self.blocks.append(
84
+ ExtraConvBlock(self.channel_dim, self.channel_multiplier)
85
+ )
86
+
87
+ def forward(self, x):
88
+ for block in self.blocks:
89
+ x = block(x)
90
+
91
+ return x
92
+
93
+
94
+ class ConvChannelsMixer(nn.Module):
95
+ """Linear activation block for PIPs's MLP Mixer."""
96
+
97
+ def __init__(self, in_channels):
98
+ super().__init__()
99
+ self.mlp2_up = nn.Linear(in_channels, in_channels * 4)
100
+ self.mlp2_down = nn.Linear(in_channels * 4, in_channels)
101
+
102
+ def forward(self, x):
103
+ x = self.mlp2_up(x)
104
+ x = F.gelu(x, approximate='tanh')
105
+ x = self.mlp2_down(x)
106
+ return x
107
+
108
+
109
+ class PIPsConvBlock(nn.Module):
110
+ """Convolutional block for PIPs's MLP Mixer."""
111
+
112
+ def __init__(
113
+ self, in_channels, kernel_shape=3, use_causal_conv=False, block_idx=None
114
+ ):
115
+ super().__init__()
116
+ self.use_causal_conv = use_causal_conv
117
+ self.block_name = f'block_{block_idx}'
118
+ self.kernel_shape = kernel_shape
119
+
120
+ self.layer_norm = nn.LayerNorm(
121
+ normalized_shape=in_channels, elementwise_affine=True, bias=False
122
+ )
123
+ self.mlp1_up = nn.Conv1d(
124
+ in_channels,
125
+ in_channels * 4,
126
+ kernel_shape,
127
+ stride=1,
128
+ padding=0 if self.use_causal_conv else 1,
129
+ groups=in_channels,
130
+ )
131
+
132
+ self.mlp1_up_1 = nn.Conv1d(
133
+ in_channels * 4,
134
+ in_channels * 4,
135
+ kernel_shape,
136
+ stride=1,
137
+ padding=0 if self.use_causal_conv else 1,
138
+ groups=in_channels * 4,
139
+ )
140
+ self.layer_norm_1 = nn.LayerNorm(
141
+ normalized_shape=in_channels, elementwise_affine=True, bias=False
142
+ )
143
+ self.conv_channels_mixer = ConvChannelsMixer(in_channels)
144
+
145
+ def forward(self, x, causal_context=None, get_causal_context=False):
146
+ to_skip = x
147
+ x = self.layer_norm(x)
148
+ new_causal_context = {}
149
+ num_extra = 0
150
+
151
+ if causal_context is not None:
152
+ name1 = self.block_name + '_causal_1'
153
+ x = torch.cat([causal_context[name1], x], dim=-2)
154
+ num_extra = causal_context[name1].shape[-2]
155
+ new_causal_context[name1] = x[..., -(self.kernel_shape - 1) :, :]
156
+
157
+ x = x.permute(0, 2, 1)
158
+ if self.use_causal_conv:
159
+ x = F.pad(x, (2, 0))
160
+ x = self.mlp1_up(x)
161
+
162
+ x = F.gelu(x, approximate='tanh')
163
+
164
+ if causal_context is not None:
165
+ x = x.permute(0, 2, 1)
166
+ name2 = self.block_name + '_causal_2'
167
+ num_extra = causal_context[name2].shape[-2]
168
+ x = torch.cat([causal_context[name2], x[..., num_extra:, :]], dim=-2)
169
+ new_causal_context[name2] = x[..., -(self.kernel_shape - 1) :, :]
170
+ x = x.permute(0, 2, 1)
171
+
172
+ if self.use_causal_conv:
173
+ x = F.pad(x, (2, 0))
174
+ x = self.mlp1_up_1(x)
175
+ x = x.permute(0, 2, 1)
176
+
177
+ if causal_context is not None:
178
+ x = x[..., num_extra:, :]
179
+
180
+ x = x[..., 0::4] + x[..., 1::4] + x[..., 2::4] + x[..., 3::4]
181
+
182
+ x = x + to_skip
183
+ to_skip = x
184
+ x = self.layer_norm_1(x)
185
+ x = self.conv_channels_mixer(x)
186
+
187
+ x = x + to_skip
188
+ return x, new_causal_context
189
+
190
+
191
+ class PIPSMLPMixer(nn.Module):
192
+ """Depthwise-conv version of PIPs's MLP Mixer."""
193
+
194
+ def __init__(
195
+ self,
196
+ input_channels: int,
197
+ output_channels: int,
198
+ hidden_dim: int = 512,
199
+ num_blocks: int = 12,
200
+ kernel_shape: int = 3,
201
+ use_causal_conv: bool = False,
202
+ ):
203
+ """Inits Mixer module.
204
+
205
+ A depthwise-convolutional version of a MLP Mixer for processing images.
206
+
207
+ Args:
208
+ input_channels (int): The number of input channels.
209
+ output_channels (int): The number of output channels.
210
+ hidden_dim (int, optional): The dimension of the hidden layer. Defaults
211
+ to 512.
212
+ num_blocks (int, optional): The number of convolution blocks in the
213
+ mixer. Defaults to 12.
214
+ kernel_shape (int, optional): The size of the kernel in the convolution
215
+ blocks. Defaults to 3.
216
+ use_causal_conv (bool, optional): Whether to use causal convolutions.
217
+ Defaults to False.
218
+ """
219
+
220
+ super().__init__()
221
+ self.hidden_dim = hidden_dim
222
+ self.num_blocks = num_blocks
223
+ self.use_causal_conv = use_causal_conv
224
+ self.linear = nn.Linear(input_channels, self.hidden_dim)
225
+ self.layer_norm = nn.LayerNorm(
226
+ normalized_shape=hidden_dim, elementwise_affine=True, bias=False
227
+ )
228
+ self.linear_1 = nn.Linear(hidden_dim, output_channels)
229
+ self.blocks = nn.ModuleList([
230
+ PIPsConvBlock(
231
+ hidden_dim, kernel_shape, self.use_causal_conv, block_idx=i
232
+ )
233
+ for i in range(num_blocks)
234
+ ])
235
+
236
+ def forward(self, x, causal_context=None, get_causal_context=False):
237
+ x = self.linear(x)
238
+ all_causal_context = {}
239
+ for block in self.blocks:
240
+ x, new_causal_context = block(x, causal_context, get_causal_context)
241
+ if get_causal_context:
242
+ all_causal_context.update(new_causal_context)
243
+
244
+ x = self.layer_norm(x)
245
+ x = self.linear_1(x)
246
+ return x, all_causal_context
247
+
248
+
249
+ class BlockV2(nn.Module):
250
+ """ResNet V2 block."""
251
+
252
+ def __init__(
253
+ self,
254
+ channels_in: int,
255
+ channels_out: int,
256
+ stride: Union[int, Sequence[int]],
257
+ use_projection: bool,
258
+ ):
259
+ super().__init__()
260
+ self.padding = (1, 1, 1, 1)
261
+ # Handle assymetric padding created by padding="SAME" in JAX/LAX
262
+ if stride == 1:
263
+ self.padding = (1, 1, 1, 1)
264
+ elif stride == 2:
265
+ self.padding = (0, 2, 0, 2)
266
+ else:
267
+ raise ValueError(
268
+ 'Check correct padding using padtype_to_padsin jax._src.lax.lax'
269
+ )
270
+
271
+ self.use_projection = use_projection
272
+ if self.use_projection:
273
+ self.proj_conv = Conv2dSamePadding(
274
+ in_channels=channels_in,
275
+ out_channels=channels_out,
276
+ kernel_size=1,
277
+ stride=stride,
278
+ padding=0,
279
+ bias=False,
280
+ )
281
+
282
+ self.bn_0 = nn.InstanceNorm2d(
283
+ num_features=channels_in,
284
+ eps=1e-05,
285
+ momentum=0.1,
286
+ affine=True,
287
+ track_running_stats=False,
288
+ )
289
+ self.conv_0 = Conv2dSamePadding(
290
+ in_channels=channels_in,
291
+ out_channels=channels_out,
292
+ kernel_size=3,
293
+ stride=stride,
294
+ padding=0,
295
+ bias=False,
296
+ )
297
+
298
+ self.conv_1 = Conv2dSamePadding(
299
+ in_channels=channels_out,
300
+ out_channels=channels_out,
301
+ kernel_size=3,
302
+ stride=1,
303
+ padding=1,
304
+ bias=False,
305
+ )
306
+ self.bn_1 = nn.InstanceNorm2d(
307
+ num_features=channels_out,
308
+ eps=1e-05,
309
+ momentum=0.1,
310
+ affine=True,
311
+ track_running_stats=False,
312
+ )
313
+
314
+ def forward(self, inputs):
315
+ x = shortcut = inputs
316
+
317
+ x = self.bn_0(x)
318
+ x = torch.relu(x)
319
+ if self.use_projection:
320
+ shortcut = self.proj_conv(x)
321
+
322
+ x = self.conv_0(x)
323
+
324
+ x = self.bn_1(x)
325
+ x = torch.relu(x)
326
+ # no issues with padding here as this layer always has stride 1
327
+ x = self.conv_1(x)
328
+
329
+ return x + shortcut
330
+
331
+
332
+ class BlockGroup(nn.Module):
333
+ """Higher level block for ResNet implementation."""
334
+
335
+ def __init__(
336
+ self,
337
+ channels_in: int,
338
+ channels_out: int,
339
+ num_blocks: int,
340
+ stride: Union[int, Sequence[int]],
341
+ use_projection: bool,
342
+ ):
343
+ super().__init__()
344
+ blocks = []
345
+ for i in range(num_blocks):
346
+ blocks.append(
347
+ BlockV2(
348
+ channels_in=channels_in if i == 0 else channels_out,
349
+ channels_out=channels_out,
350
+ stride=(1 if i else stride),
351
+ use_projection=(i == 0 and use_projection),
352
+ )
353
+ )
354
+ self.blocks = nn.ModuleList(blocks)
355
+
356
+ def forward(self, inputs):
357
+ out = inputs
358
+ for block in self.blocks:
359
+ out = block(out)
360
+ return out
361
+
362
+
363
+ class ResNet(nn.Module):
364
+ """ResNet model."""
365
+
366
+ def __init__(
367
+ self,
368
+ blocks_per_group: Sequence[int],
369
+ channels_per_group: Sequence[int] = (64, 128, 256, 512),
370
+ use_projection: Sequence[bool] = (True, True, True, True),
371
+ strides: Sequence[int] = (1, 2, 2, 2),
372
+ ):
373
+ """Initializes a ResNet model with customizable layers and configurations.
374
+
375
+ This constructor allows defining the architecture of a ResNet model by
376
+ setting the number of blocks, channels, projection usage, and strides for
377
+ each group of blocks within the network. It provides flexibility in
378
+ creating various ResNet configurations.
379
+
380
+ Args:
381
+ blocks_per_group: A sequence of 4 integers, each indicating the number
382
+ of residual blocks in each group.
383
+ channels_per_group: A sequence of 4 integers, each specifying the number
384
+ of output channels for the blocks in each group. Defaults to (64, 128,
385
+ 256, 512).
386
+ use_projection: A sequence of 4 booleans, each indicating whether to use
387
+ a projection shortcut (True) or an identity shortcut (False) in each
388
+ group. Defaults to (True, True, True, True).
389
+ strides: A sequence of 4 integers, each specifying the stride size for
390
+ the convolutions in each group. Defaults to (1, 2, 2, 2).
391
+
392
+ The ResNet model created will have 4 groups, with each group's
393
+ architecture defined by the corresponding elements in these sequences.
394
+ """
395
+ super().__init__()
396
+
397
+ self.initial_conv = Conv2dSamePadding(
398
+ in_channels=3,
399
+ out_channels=channels_per_group[0],
400
+ kernel_size=(7, 7),
401
+ stride=2,
402
+ padding=0,
403
+ bias=False,
404
+ )
405
+
406
+ block_groups = []
407
+ for i, _ in enumerate(strides):
408
+ block_groups.append(
409
+ BlockGroup(
410
+ channels_in=channels_per_group[i - 1] if i > 0 else 64,
411
+ channels_out=channels_per_group[i],
412
+ num_blocks=blocks_per_group[i],
413
+ stride=strides[i],
414
+ use_projection=use_projection[i],
415
+ )
416
+ )
417
+ self.block_groups = nn.ModuleList(block_groups)
418
+
419
+ def forward(self, inputs):
420
+ result = {}
421
+ out = inputs
422
+ out = self.initial_conv(out)
423
+ result['initial_conv'] = out
424
+
425
+ for block_id, block_group in enumerate(self.block_groups):
426
+ out = block_group(out)
427
+ result[f'resnet_unit_{block_id}'] = out
428
+
429
+ return result
locotrack_pytorch/models/utils.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Pytorch model utilities."""
17
+ import math
18
+ from typing import Any, Sequence, Union
19
+ from einshape.src import abstract_ops
20
+ from einshape.src import backend
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+
25
+
26
+ def bilinear(x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
27
+ """Resizes a 5D tensor using bilinear interpolation.
28
+
29
+ Args:
30
+ x: A 5D tensor of shape (B, T, W, H, C) where B is batch size, T is
31
+ time, W is width, H is height, and C is the number of channels.
32
+ resolution: The target resolution as a tuple (new_width, new_height).
33
+
34
+ Returns:
35
+ The resized tensor.
36
+ """
37
+ b, t, h, w, c = x.size()
38
+ x = x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
39
+ x = F.interpolate(x, size=resolution, mode='bilinear', align_corners=False)
40
+ b, _, h, w = x.size()
41
+ x = x.reshape(b, t, c, h, w).permute(0, 1, 3, 4, 2)
42
+ return x
43
+
44
+
45
+ def map_coordinates_3d(
46
+ feats: torch.Tensor, coordinates: torch.Tensor
47
+ ) -> torch.Tensor:
48
+ """Maps 3D coordinates to corresponding features using bilinear interpolation.
49
+
50
+ Args:
51
+ feats: A 5D tensor of features with shape (B, W, H, D, C), where B is batch
52
+ size, W is width, H is height, D is depth, and C is the number of
53
+ channels.
54
+ coordinates: A 3D tensor of coordinates with shape (B, N, 3), where N is the
55
+ number of coordinates and the last dimension represents (W, H, D)
56
+ coordinates.
57
+
58
+ Returns:
59
+ The mapped features tensor.
60
+ """
61
+ x = feats.permute(0, 4, 1, 2, 3)
62
+ y = coordinates[:, :, None, None, :].float().clone()
63
+ y[..., 0] = y[..., 0] + 0.5
64
+ y = 2 * (y / torch.tensor(x.shape[2:], device=y.device)) - 1
65
+ y = torch.flip(y, dims=(-1,))
66
+ out = (
67
+ F.grid_sample(
68
+ x, y, mode='bilinear', align_corners=False, padding_mode='border'
69
+ )
70
+ .squeeze(dim=(3, 4))
71
+ .permute(0, 2, 1)
72
+ )
73
+ return out
74
+
75
+
76
+ def map_coordinates_2d(
77
+ feats: torch.Tensor, coordinates: torch.Tensor
78
+ ) -> torch.Tensor:
79
+ """Maps 2D coordinates to feature maps using bilinear interpolation.
80
+
81
+ The function performs bilinear interpolation on the feature maps (`feats`)
82
+ at the specified `coordinates`. The coordinates are normalized between
83
+ -1 and 1 The result is a tensor of sampled features corresponding
84
+ to these coordinates.
85
+
86
+ Args:
87
+ feats (Tensor): A 5D tensor of shape (N, T, H, W, C) representing feature
88
+ maps, where N is the batch size, T is the number of frames, H and W are
89
+ height and width, and C is the number of channels.
90
+ coordinates (Tensor): A 5D tensor of shape (N, P, T, S, XY) representing
91
+ coordinates, where N is the batch size, P is the number of points, T is
92
+ the number of frames, S is the number of samples, and XY represents the 2D
93
+ coordinates.
94
+
95
+ Returns:
96
+ Tensor: A 5D tensor of the sampled features corresponding to the
97
+ given coordinates, of shape (N, P, T, S, C).
98
+ """
99
+ n, t, h, w, c = feats.shape
100
+ x = feats.permute(0, 1, 4, 2, 3).view(n * t, c, h, w)
101
+
102
+ n, p, t, s, xy = coordinates.shape
103
+ y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy)
104
+ y = 2 * (y / h) - 1
105
+ y = torch.flip(y, dims=(-1,)).float()
106
+
107
+ out = F.grid_sample(
108
+ x, y, mode='bilinear', align_corners=False, padding_mode='zeros'
109
+ )
110
+ _, c, _, _ = out.shape
111
+ out = out.permute(0, 2, 3, 1).view(n, t, p, s, c).permute(0, 2, 1, 3, 4)
112
+
113
+ return out
114
+
115
+
116
+ def soft_argmax_heatmap_batched(softmax_val, threshold=5):
117
+ """Test if two image resolutions are the same."""
118
+ b, h, w, d1, d2 = softmax_val.shape
119
+ y, x = torch.meshgrid(
120
+ torch.arange(d1, device=softmax_val.device),
121
+ torch.arange(d2, device=softmax_val.device),
122
+ indexing='ij',
123
+ )
124
+ coords = torch.stack([x + 0.5, y + 0.5], dim=-1).to(softmax_val.device)
125
+ softmax_val_flat = softmax_val.reshape(b, h, w, -1)
126
+ argmax_pos = torch.argmax(softmax_val_flat, dim=-1)
127
+
128
+ pos = coords.reshape(-1, 2)[argmax_pos]
129
+ valid = (
130
+ torch.sum(
131
+ torch.square(
132
+ coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :]
133
+ ),
134
+ dim=-1,
135
+ keepdims=True,
136
+ )
137
+ < threshold**2
138
+ )
139
+
140
+ weighted_sum = torch.sum(
141
+ coords[None, None, None, :, :, :]
142
+ * valid
143
+ * softmax_val[:, :, :, :, :, None],
144
+ dim=(3, 4),
145
+ )
146
+ sum_of_weights = torch.maximum(
147
+ torch.sum(valid * softmax_val[:, :, :, :, :, None], dim=(3, 4)),
148
+ torch.tensor(1e-12, device=softmax_val.device),
149
+ )
150
+ return weighted_sum / sum_of_weights
151
+
152
+
153
+ def heatmaps_to_points(
154
+ all_pairs_softmax,
155
+ image_shape,
156
+ threshold=5,
157
+ query_points=None,
158
+ ):
159
+ """Convert heatmaps to points using soft argmax."""
160
+
161
+ out_points = soft_argmax_heatmap_batched(all_pairs_softmax, threshold)
162
+ feature_grid_shape = all_pairs_softmax.shape[1:]
163
+ # Note: out_points is now [x, y]; we need to divide by [width, height].
164
+ # image_shape[3] is width and image_shape[2] is height.
165
+ out_points = convert_grid_coordinates(
166
+ out_points,
167
+ feature_grid_shape[3:1:-1],
168
+ image_shape[3:1:-1],
169
+ )
170
+ assert feature_grid_shape[1] == image_shape[1]
171
+ if query_points is not None:
172
+ # The [..., 0:1] is because we only care about the frame index.
173
+ query_frame = convert_grid_coordinates(
174
+ query_points.detach(),
175
+ image_shape[1:4],
176
+ feature_grid_shape[1:4],
177
+ coordinate_format='tyx',
178
+ )[..., 0:1]
179
+
180
+ query_frame = torch.round(query_frame)
181
+ frame_indices = torch.arange(image_shape[1], device=query_frame.device)[
182
+ None, None, :
183
+ ]
184
+ is_query_point = query_frame == frame_indices
185
+
186
+ is_query_point = is_query_point[:, :, :, None]
187
+ out_points = (
188
+ out_points * ~is_query_point
189
+ + torch.flip(query_points[:, :, None], dims=(-1,))[..., 0:2]
190
+ * is_query_point
191
+ )
192
+
193
+ return out_points
194
+
195
+
196
+ def is_same_res(r1, r2):
197
+ """Test if two image resolutions are the same."""
198
+ return all([x == y for x, y in zip(r1, r2)])
199
+
200
+
201
+ def convert_grid_coordinates(
202
+ coords: torch.Tensor,
203
+ input_grid_size: Sequence[int],
204
+ output_grid_size: Sequence[int],
205
+ coordinate_format: str = 'xy',
206
+ ) -> torch.Tensor:
207
+ """Convert grid coordinates to correct format."""
208
+ if isinstance(input_grid_size, tuple):
209
+ input_grid_size = torch.tensor(input_grid_size, device=coords.device)
210
+ if isinstance(output_grid_size, tuple):
211
+ output_grid_size = torch.tensor(output_grid_size, device=coords.device)
212
+
213
+ if coordinate_format == 'xy':
214
+ if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
215
+ raise ValueError(
216
+ 'If coordinate_format is xy, the shapes must be length 2.'
217
+ )
218
+ elif coordinate_format == 'tyx':
219
+ if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
220
+ raise ValueError(
221
+ 'If coordinate_format is tyx, the shapes must be length 3.'
222
+ )
223
+ if input_grid_size[0] != output_grid_size[0]:
224
+ raise ValueError('converting frame count is not supported.')
225
+ else:
226
+ raise ValueError('Recognized coordinate formats are xy and tyx.')
227
+
228
+ position_in_grid = coords
229
+ position_in_grid = position_in_grid * output_grid_size / input_grid_size
230
+
231
+ return position_in_grid
232
+
233
+
234
+ class _JaxBackend(backend.Backend[torch.Tensor]):
235
+ """Einshape implementation for PyTorch."""
236
+
237
+ # https://github.com/vacancy/einshape/blob/main/einshape/src/pytorch/pytorch_ops.py
238
+
239
+ def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor:
240
+ return x.reshape(op.shape)
241
+
242
+ def transpose(
243
+ self, x: torch.Tensor, op: abstract_ops.Transpose
244
+ ) -> torch.Tensor:
245
+ return x.permute(op.perm)
246
+
247
+ def broadcast(
248
+ self, x: torch.Tensor, op: abstract_ops.Broadcast
249
+ ) -> torch.Tensor:
250
+ shape = op.transform_shape(x.shape)
251
+ for axis_position in sorted(op.axis_sizes.keys()):
252
+ x = x.unsqueeze(axis_position)
253
+ return x.expand(shape)
254
+
255
+
256
+ def einshape(
257
+ equation: str, value: Union[torch.Tensor, Any], **index_sizes: int
258
+ ) -> torch.Tensor:
259
+ """Reshapes `value` according to the given Shape Equation.
260
+
261
+ Args:
262
+ equation: The Shape Equation specifying the index regrouping and reordering.
263
+ value: Input tensor, or tensor-like object.
264
+ **index_sizes: Sizes of indices, where they cannot be inferred from
265
+ `input_shape`.
266
+
267
+ Returns:
268
+ Tensor derived from `value` by reshaping as specified by `equation`.
269
+ """
270
+ if not isinstance(value, torch.Tensor):
271
+ value = torch.tensor(value)
272
+ return _JaxBackend().exec(equation, value, value.shape, **index_sizes)
273
+
274
+
275
+ def generate_default_resolutions(full_size, train_size, num_levels=None):
276
+ """Generate a list of logarithmically-spaced resolutions.
277
+
278
+ Generated resolutions are between train_size and full_size, inclusive, with
279
+ num_levels different resolutions total. Useful for generating the input to
280
+ refinement_resolutions in PIPs.
281
+
282
+ Args:
283
+ full_size: 2-tuple of ints. The full image size desired.
284
+ train_size: 2-tuple of ints. The smallest refinement level. Should
285
+ typically match the training resolution, which is (256, 256) for TAPIR.
286
+ num_levels: number of levels. Typically each resolution should be less than
287
+ twice the size of prior resolutions.
288
+
289
+ Returns:
290
+ A list of resolutions.
291
+ """
292
+ if all([x == y for x, y in zip(train_size, full_size)]):
293
+ return [train_size]
294
+
295
+ if num_levels is None:
296
+ size_ratio = np.array(full_size) / np.array(train_size)
297
+ num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1)
298
+
299
+ if num_levels <= 1:
300
+ return [train_size]
301
+
302
+ h, w = full_size[0:2]
303
+ if h % 8 != 0 or w % 8 != 0:
304
+ print(
305
+ 'Warning: output size is not a multiple of 8. Final layer '
306
+ + 'will round size down.'
307
+ )
308
+ ll_h, ll_w = train_size[0:2]
309
+
310
+ sizes = []
311
+ for i in range(num_levels):
312
+ size = (
313
+ int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8,
314
+ int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8,
315
+ )
316
+ sizes.append(size)
317
+ return sizes
318
+
319
+
320
+ class Conv2dSamePadding(torch.nn.Conv2d):
321
+
322
+ def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
323
+ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
324
+
325
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
326
+ ih, iw = x.size()[-2:]
327
+
328
+ pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
329
+ pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
330
+
331
+ if pad_h > 0 or pad_w > 0:
332
+ x = F.pad(
333
+ x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
334
+ )
335
+ return F.conv2d(
336
+ x,
337
+ self.weight,
338
+ self.bias,
339
+ self.stride,
340
+ # self.padding,
341
+ 0,
342
+ self.dilation,
343
+ self.groups,
344
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ einshape==1.0
2
+ gradio==4.40.0
3
+ mediapy==1.2.2
4
+ opencv-python==4.10.0.84
5
+ torch==2.4.0
6
+ torchaudio==2.4.0
7
+ torchvision==0.19.0
viz_utils.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualization utility functions."""
2
+
3
+ import colorsys
4
+ import random
5
+ from typing import List, Optional, Sequence, Tuple
6
+
7
+ import numpy as np
8
+
9
+
10
+ # Generate random colormaps for visualizing different points.
11
+ def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
12
+ """Gets colormap for points."""
13
+ colors = []
14
+ for i in np.arange(0.0, 360.0, 360.0 / num_colors):
15
+ hue = i / 360.0
16
+ lightness = (50 + np.random.rand() * 10) / 100.0
17
+ saturation = (90 + np.random.rand() * 10) / 100.0
18
+ color = colorsys.hls_to_rgb(hue, lightness, saturation)
19
+ colors.append(
20
+ (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
21
+ )
22
+ random.shuffle(colors)
23
+ return colors
24
+
25
+
26
+ def paint_point_track(
27
+ frames: np.ndarray,
28
+ point_tracks: np.ndarray,
29
+ visibles: np.ndarray,
30
+ colormap: Optional[List[Tuple[int, int, int]]] = None,
31
+ ) -> np.ndarray:
32
+ """Converts a sequence of points to color code video.
33
+
34
+ Args:
35
+ frames: [num_frames, height, width, 3], np.uint8, [0, 255]
36
+ point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
37
+ visibles: [num_points, num_frames], bool
38
+ colormap: colormap for points, each point has a different RGB color.
39
+
40
+ Returns:
41
+ video: [num_frames, height, width, 3], np.uint8, [0, 255]
42
+ """
43
+ num_points, num_frames = point_tracks.shape[0:2]
44
+ if colormap is None:
45
+ colormap = get_colors(num_colors=num_points)
46
+ height, width = frames.shape[1:3]
47
+ dot_size_as_fraction_of_min_edge = 0.015
48
+ radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
49
+ diam = radius * 2 + 1
50
+ quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
51
+ quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
52
+ icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
53
+ sharpness = 0.15
54
+ icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
55
+ icon = 1 - icon[:, :, np.newaxis]
56
+ icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
57
+ icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
58
+ icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
59
+ icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
60
+
61
+ video = frames.copy()
62
+ for t in range(num_frames):
63
+ # Pad so that points that extend outside the image frame don't crash us
64
+ image = np.pad(
65
+ video[t],
66
+ [
67
+ (radius + 1, radius + 1),
68
+ (radius + 1, radius + 1),
69
+ (0, 0),
70
+ ],
71
+ )
72
+ for i in range(num_points):
73
+ # The icon is centered at the center of a pixel, but the input coordinates
74
+ # are raster coordinates. Therefore, to render a point at (1,1) (which
75
+ # lies on the corner between four pixels), we need 1/4 of the icon placed
76
+ # centered on the 0'th row, 0'th column, etc. We need to subtract
77
+ # 0.5 to make the fractional position come out right.
78
+ x, y = point_tracks[i, t, :] + 0.5
79
+ x = min(max(x, 0.0), width)
80
+ y = min(max(y, 0.0), height)
81
+
82
+ if visibles[i, t]:
83
+ x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
84
+ x2, y2 = x1 + 1, y1 + 1
85
+
86
+ # bilinear interpolation
87
+ patch = (
88
+ icon1 * (x2 - x) * (y2 - y)
89
+ + icon2 * (x2 - x) * (y - y1)
90
+ + icon3 * (x - x1) * (y2 - y)
91
+ + icon4 * (x - x1) * (y - y1)
92
+ )
93
+ x_ub = x1 + 2 * radius + 2
94
+ y_ub = y1 + 2 * radius + 2
95
+ image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
96
+ y1:y_ub, x1:x_ub, :
97
+ ] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
98
+
99
+ # Remove the pad
100
+ video[t] = image[
101
+ radius + 1 : -radius - 1, radius + 1 : -radius - 1
102
+ ].astype(np.uint8)
103
+ return video
104
+