Seokju Cho
commited on
Commit
·
f1586f7
1
Parent(s):
058b9ed
initial commit
Browse files- .gitignore +162 -0
- app.py +444 -0
- locotrack_pytorch/README.md +62 -0
- locotrack_pytorch/config/default.ini +25 -0
- locotrack_pytorch/data/evaluation_datasets.py +784 -0
- locotrack_pytorch/data/kubric_data.py +243 -0
- locotrack_pytorch/environment.yml +151 -0
- locotrack_pytorch/experiment.py +238 -0
- locotrack_pytorch/model_utils.py +165 -0
- locotrack_pytorch/models/cmdtop.py +45 -0
- locotrack_pytorch/models/locotrack_model.py +1053 -0
- locotrack_pytorch/models/nets.py +429 -0
- locotrack_pytorch/models/utils.py +344 -0
- requirements.txt +7 -0
- viz_utils.py +104 -0
.gitignore
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
app.py
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import uuid
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import mediapy
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
import matplotlib
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from locotrack_pytorch.models.locotrack_model import load_model
|
13 |
+
from viz_utils import paint_point_track
|
14 |
+
|
15 |
+
|
16 |
+
PREVIEW_WIDTH = 768 # Width of the preview video
|
17 |
+
VIDEO_INPUT_RESO = (256, 256) # Resolution of the input video
|
18 |
+
POINT_SIZE = 4 # Size of the query point in the preview video
|
19 |
+
FRAME_LIMIT = 300 # Limit the number of frames to process
|
20 |
+
|
21 |
+
|
22 |
+
def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
|
23 |
+
print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")
|
24 |
+
|
25 |
+
current_frame = video_queried_preview[int(frame_num)]
|
26 |
+
|
27 |
+
# Get the mouse click
|
28 |
+
query_points[int(frame_num)].append((evt.index[0], evt.index[1], frame_num))
|
29 |
+
|
30 |
+
# Choose the color for the point from matplotlib colormap
|
31 |
+
color = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
|
32 |
+
color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
|
33 |
+
print(f"Color: {color}")
|
34 |
+
query_points_color[int(frame_num)].append(color)
|
35 |
+
|
36 |
+
# Draw the point on the frame
|
37 |
+
x, y = evt.index
|
38 |
+
current_frame_draw = cv2.circle(current_frame, (x, y), POINT_SIZE, color, -1)
|
39 |
+
|
40 |
+
# Update the frame
|
41 |
+
video_queried_preview[int(frame_num)] = current_frame_draw
|
42 |
+
|
43 |
+
# Update the query count
|
44 |
+
query_count += 1
|
45 |
+
return (
|
46 |
+
current_frame_draw, # Updated frame for preview
|
47 |
+
video_queried_preview, # Updated preview video
|
48 |
+
query_points, # Updated query points
|
49 |
+
query_points_color, # Updated query points color
|
50 |
+
query_count # Updated query count
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
|
55 |
+
if len(query_points[int(frame_num)]) == 0:
|
56 |
+
return (
|
57 |
+
video_queried_preview[int(frame_num)],
|
58 |
+
video_queried_preview,
|
59 |
+
query_points,
|
60 |
+
query_points_color,
|
61 |
+
query_count
|
62 |
+
)
|
63 |
+
|
64 |
+
# Get the last point
|
65 |
+
query_points[int(frame_num)].pop(-1)
|
66 |
+
query_points_color[int(frame_num)].pop(-1)
|
67 |
+
|
68 |
+
# Redraw the frame
|
69 |
+
current_frame_draw = video_preview[int(frame_num)].copy()
|
70 |
+
for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
|
71 |
+
x, y, _ = point
|
72 |
+
current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)
|
73 |
+
|
74 |
+
# Update the query count
|
75 |
+
query_count -= 1
|
76 |
+
|
77 |
+
# Update the frame
|
78 |
+
video_queried_preview[int(frame_num)] = current_frame_draw
|
79 |
+
return (
|
80 |
+
current_frame_draw, # Updated frame for preview
|
81 |
+
video_queried_preview, # Updated preview video
|
82 |
+
query_points, # Updated query points
|
83 |
+
query_points_color, # Updated query points color
|
84 |
+
query_count # Updated query count
|
85 |
+
)
|
86 |
+
|
87 |
+
|
88 |
+
def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
|
89 |
+
query_count -= len(query_points[int(frame_num)])
|
90 |
+
|
91 |
+
query_points[int(frame_num)] = []
|
92 |
+
query_points_color[int(frame_num)] = []
|
93 |
+
|
94 |
+
video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()
|
95 |
+
|
96 |
+
return (
|
97 |
+
video_preview[int(frame_num)], # Set the preview frame to the original frame
|
98 |
+
video_queried_preview,
|
99 |
+
query_points, # Cleared query points
|
100 |
+
query_points_color, # Cleared query points color
|
101 |
+
query_count # New query count
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
def clear_all_fn(frame_num, video_preview):
|
107 |
+
return (
|
108 |
+
video_preview[int(frame_num)],
|
109 |
+
video_preview.copy(),
|
110 |
+
[[] for _ in range(len(video_preview))],
|
111 |
+
[[] for _ in range(len(video_preview))],
|
112 |
+
0
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
def choose_frame(frame_num, video_preview_array):
|
117 |
+
return video_preview_array[int(frame_num)]
|
118 |
+
|
119 |
+
|
120 |
+
def extract_feature(video_input, model_size="small"):
|
121 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
122 |
+
dtype = torch.bfloat16 if device == "cuda" else torch.float16
|
123 |
+
|
124 |
+
model = load_model(model_size=model_size).to(device)
|
125 |
+
|
126 |
+
video_input = (video_input / 255.0) * 2 - 1
|
127 |
+
video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
|
128 |
+
|
129 |
+
with torch.autocast(device_type=device, dtype=dtype):
|
130 |
+
with torch.no_grad():
|
131 |
+
feature = model.get_feature_grids(video_input)
|
132 |
+
|
133 |
+
return feature
|
134 |
+
|
135 |
+
|
136 |
+
def preprocess_video_input(video_path, model_size):
|
137 |
+
video_arr = mediapy.read_video(video_path)
|
138 |
+
video_fps = video_arr.metadata.fps
|
139 |
+
num_frames = video_arr.shape[0]
|
140 |
+
if num_frames > FRAME_LIMIT:
|
141 |
+
gr.Warning(f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.", duration=5)
|
142 |
+
video_arr = video_arr[:FRAME_LIMIT]
|
143 |
+
num_frames = FRAME_LIMIT
|
144 |
+
|
145 |
+
# Resize to preview size for faster processing, width = PREVIEW_WIDTH
|
146 |
+
height, width = video_arr.shape[1:3]
|
147 |
+
new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
|
148 |
+
|
149 |
+
preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
|
150 |
+
input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
|
151 |
+
|
152 |
+
preview_video = np.array(preview_video)
|
153 |
+
input_video = np.array(input_video)
|
154 |
+
|
155 |
+
video_feature = extract_feature(input_video, model_size)
|
156 |
+
|
157 |
+
return (
|
158 |
+
video_arr, # Original video
|
159 |
+
preview_video, # Original preview video, resized for faster processing
|
160 |
+
preview_video.copy(), # Copy of preview video for visualization
|
161 |
+
input_video, # Resized video input for model
|
162 |
+
video_feature, # Extracted feature
|
163 |
+
video_fps, # Set the video FPS
|
164 |
+
gr.update(open=False), # Close the video input drawer
|
165 |
+
model_size, # Set the model size
|
166 |
+
preview_video[0], # Set the preview frame to the first frame
|
167 |
+
gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=True), # Set slider interactive
|
168 |
+
[[] for _ in range(num_frames)], # Set query_points to empty
|
169 |
+
[[] for _ in range(num_frames)], # Set query_points_color to empty
|
170 |
+
[[] for _ in range(num_frames)],
|
171 |
+
0, # Set query count to 0
|
172 |
+
gr.update(interactive=True), # Make the buttons interactive
|
173 |
+
gr.update(interactive=True),
|
174 |
+
gr.update(interactive=True),
|
175 |
+
gr.update(interactive=True),
|
176 |
+
)
|
177 |
+
|
178 |
+
|
179 |
+
def track(
|
180 |
+
model_size,
|
181 |
+
video_preview,
|
182 |
+
video_input,
|
183 |
+
video_feature,
|
184 |
+
video_fps,
|
185 |
+
query_points,
|
186 |
+
query_points_color,
|
187 |
+
query_count,
|
188 |
+
):
|
189 |
+
if query_count == 0:
|
190 |
+
gr.Warning("Please add query points before tracking.", duration=5)
|
191 |
+
return None
|
192 |
+
|
193 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
194 |
+
dtype = torch.bfloat16 if device == "cuda" else torch.float16
|
195 |
+
|
196 |
+
# Convert query points to tensor, normalize to input resolution
|
197 |
+
query_points_tensor = []
|
198 |
+
for frame_points in query_points:
|
199 |
+
query_points_tensor.extend(frame_points)
|
200 |
+
|
201 |
+
query_points_tensor = torch.tensor(query_points_tensor).float()
|
202 |
+
query_points_tensor *= torch.tensor([
|
203 |
+
VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1
|
204 |
+
]) / torch.tensor([
|
205 |
+
[video_preview.shape[2], video_preview.shape[1], 1]
|
206 |
+
])
|
207 |
+
query_points_tensor = query_points_tensor[None].flip(-1).to(device, dtype) # xyt -> tyx
|
208 |
+
|
209 |
+
# Preprocess video input
|
210 |
+
video_input = (video_input / 255.0) * 2 - 1
|
211 |
+
video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
|
212 |
+
|
213 |
+
model = load_model(model_size=model_size).to(device)
|
214 |
+
with torch.autocast(device_type=device, dtype=dtype):
|
215 |
+
with torch.no_grad():
|
216 |
+
output = model(video_input, query_points_tensor, feature_grids=video_feature)
|
217 |
+
|
218 |
+
tracks = output['tracks'][0].cpu()
|
219 |
+
tracks = tracks * torch.tensor([
|
220 |
+
video_preview.shape[2], video_preview.shape[1]
|
221 |
+
]) / torch.tensor([
|
222 |
+
VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]
|
223 |
+
])
|
224 |
+
tracks = tracks.numpy()
|
225 |
+
|
226 |
+
|
227 |
+
occlusion_logits = output['occlusion']
|
228 |
+
pred_occ = torch.sigmoid(occlusion_logits)
|
229 |
+
if 'expected_dist' in output:
|
230 |
+
expected_dist = output['expected_dist']
|
231 |
+
pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(expected_dist))
|
232 |
+
|
233 |
+
pred_occ = (pred_occ > 0.5)[0].cpu().numpy()
|
234 |
+
|
235 |
+
# make color array
|
236 |
+
colors = []
|
237 |
+
for frame_colors in query_points_color:
|
238 |
+
colors.extend(frame_colors)
|
239 |
+
colors = np.array(colors)
|
240 |
+
|
241 |
+
painted_video = paint_point_track(
|
242 |
+
video_preview,
|
243 |
+
tracks,
|
244 |
+
~pred_occ,
|
245 |
+
colors,
|
246 |
+
)
|
247 |
+
|
248 |
+
# save video
|
249 |
+
video_file_name = uuid.uuid4().hex + ".mp4"
|
250 |
+
video_path = os.path.join(os.path.dirname(__file__), "tmp")
|
251 |
+
video_file_path = os.path.join(video_path, video_file_name)
|
252 |
+
os.makedirs(video_path, exist_ok=True)
|
253 |
+
|
254 |
+
mediapy.write_video(video_file_path, painted_video, fps=video_fps)
|
255 |
+
|
256 |
+
return video_file_path
|
257 |
+
|
258 |
+
|
259 |
+
with gr.Blocks() as demo:
|
260 |
+
video = gr.State()
|
261 |
+
video_queried_preview = gr.State()
|
262 |
+
video_preview = gr.State()
|
263 |
+
video_input = gr.State()
|
264 |
+
video_feautre = gr.State()
|
265 |
+
video_fps = gr.State(24)
|
266 |
+
model_size = gr.State("small")
|
267 |
+
|
268 |
+
query_points = gr.State([])
|
269 |
+
query_points_color = gr.State([])
|
270 |
+
is_tracked_query = gr.State([])
|
271 |
+
query_count = gr.State(0)
|
272 |
+
|
273 |
+
gr.Markdown("# LocoTrack Demo")
|
274 |
+
gr.Markdown("This is an interactive demo for LocoTrack. For more details, please refer to the [GitHub repository](https://github.com/KU-CVLAB/LocoTrack) or the [paper](https://arxiv.org/abs/2407.15420).")
|
275 |
+
|
276 |
+
gr.Markdown("## First step: Choose the model size and upload your video")
|
277 |
+
with gr.Row():
|
278 |
+
with gr.Accordion("Your video input", open=True) as video_in_drawer:
|
279 |
+
model_size_selection = gr.Radio(
|
280 |
+
label="Model Size",
|
281 |
+
choices=["small", "base"],
|
282 |
+
value="small",
|
283 |
+
)
|
284 |
+
video_in = gr.Video(label="Video Input", format="mp4")
|
285 |
+
|
286 |
+
gr.Markdown("## Second step: Add query points to track")
|
287 |
+
with gr.Row():
|
288 |
+
|
289 |
+
with gr.Column():
|
290 |
+
with gr.Row():
|
291 |
+
query_frames = gr.Slider(
|
292 |
+
minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
|
293 |
+
with gr.Row():
|
294 |
+
undo = gr.Button("Undo", interactive=False)
|
295 |
+
clear_frame = gr.Button("Clear Frame", interactive=False)
|
296 |
+
clear_all = gr.Button("Clear All", interactive=False)
|
297 |
+
|
298 |
+
with gr.Row():
|
299 |
+
current_frame = gr.Image(
|
300 |
+
label="Click to add query points",
|
301 |
+
type="numpy",
|
302 |
+
interactive=False
|
303 |
+
)
|
304 |
+
|
305 |
+
with gr.Row():
|
306 |
+
track_button = gr.Button("Track", interactive=False)
|
307 |
+
|
308 |
+
with gr.Column():
|
309 |
+
output_video = gr.Video(
|
310 |
+
label="Output Video",
|
311 |
+
interactive=False,
|
312 |
+
autoplay=True,
|
313 |
+
loop=True,
|
314 |
+
)
|
315 |
+
|
316 |
+
video_in.upload(
|
317 |
+
fn = preprocess_video_input,
|
318 |
+
inputs = [video_in, model_size_selection],
|
319 |
+
outputs = [
|
320 |
+
video,
|
321 |
+
video_preview,
|
322 |
+
video_queried_preview,
|
323 |
+
video_input,
|
324 |
+
video_feautre,
|
325 |
+
video_fps,
|
326 |
+
video_in_drawer,
|
327 |
+
model_size,
|
328 |
+
current_frame,
|
329 |
+
query_frames,
|
330 |
+
query_points,
|
331 |
+
query_points_color,
|
332 |
+
is_tracked_query,
|
333 |
+
query_count,
|
334 |
+
undo,
|
335 |
+
clear_frame,
|
336 |
+
clear_all,
|
337 |
+
track_button,
|
338 |
+
],
|
339 |
+
queue = False
|
340 |
+
)
|
341 |
+
|
342 |
+
query_frames.change(
|
343 |
+
fn = choose_frame,
|
344 |
+
inputs = [query_frames, video_queried_preview],
|
345 |
+
outputs = [
|
346 |
+
current_frame,
|
347 |
+
],
|
348 |
+
queue = False
|
349 |
+
)
|
350 |
+
|
351 |
+
current_frame.select(
|
352 |
+
fn = get_point,
|
353 |
+
inputs = [
|
354 |
+
query_frames,
|
355 |
+
video_queried_preview,
|
356 |
+
query_points,
|
357 |
+
query_points_color,
|
358 |
+
query_count,
|
359 |
+
],
|
360 |
+
outputs = [
|
361 |
+
current_frame,
|
362 |
+
video_queried_preview,
|
363 |
+
query_points,
|
364 |
+
query_points_color,
|
365 |
+
query_count
|
366 |
+
],
|
367 |
+
queue = False
|
368 |
+
)
|
369 |
+
|
370 |
+
undo.click(
|
371 |
+
fn = undo_point,
|
372 |
+
inputs = [
|
373 |
+
query_frames,
|
374 |
+
video_preview,
|
375 |
+
video_queried_preview,
|
376 |
+
query_points,
|
377 |
+
query_points_color,
|
378 |
+
query_count
|
379 |
+
],
|
380 |
+
outputs = [
|
381 |
+
current_frame,
|
382 |
+
video_queried_preview,
|
383 |
+
query_points,
|
384 |
+
query_points_color,
|
385 |
+
query_count
|
386 |
+
],
|
387 |
+
queue = False
|
388 |
+
)
|
389 |
+
|
390 |
+
clear_frame.click(
|
391 |
+
fn = clear_frame_fn,
|
392 |
+
inputs = [
|
393 |
+
query_frames,
|
394 |
+
video_preview,
|
395 |
+
video_queried_preview,
|
396 |
+
query_points,
|
397 |
+
query_points_color,
|
398 |
+
query_count
|
399 |
+
],
|
400 |
+
outputs = [
|
401 |
+
current_frame,
|
402 |
+
video_queried_preview,
|
403 |
+
query_points,
|
404 |
+
query_points_color,
|
405 |
+
query_count
|
406 |
+
],
|
407 |
+
queue = False
|
408 |
+
)
|
409 |
+
|
410 |
+
clear_all.click(
|
411 |
+
fn = clear_all_fn,
|
412 |
+
inputs = [
|
413 |
+
query_frames,
|
414 |
+
video_preview,
|
415 |
+
],
|
416 |
+
outputs = [
|
417 |
+
current_frame,
|
418 |
+
video_queried_preview,
|
419 |
+
query_points,
|
420 |
+
query_points_color,
|
421 |
+
query_count
|
422 |
+
],
|
423 |
+
queue = False
|
424 |
+
)
|
425 |
+
|
426 |
+
track_button.click(
|
427 |
+
fn = track,
|
428 |
+
inputs = [
|
429 |
+
model_size,
|
430 |
+
video_preview,
|
431 |
+
video_input,
|
432 |
+
video_feautre,
|
433 |
+
video_fps,
|
434 |
+
query_points,
|
435 |
+
query_points_color,
|
436 |
+
query_count,
|
437 |
+
],
|
438 |
+
outputs = [
|
439 |
+
output_video,
|
440 |
+
],
|
441 |
+
queue = True,
|
442 |
+
)
|
443 |
+
|
444 |
+
demo.launch(show_api=False, show_error=True, debug=True)
|
locotrack_pytorch/README.md
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PyTorch Implementation of LocoTrack
|
2 |
+
|
3 |
+
## Preparing the Environment
|
4 |
+
|
5 |
+
```bash
|
6 |
+
git clone https://github.com/google-research/kubric.git
|
7 |
+
|
8 |
+
conda create -n locotrack-pytorch python=3.11
|
9 |
+
conda activate locotrack-pytorch
|
10 |
+
|
11 |
+
pip install torch torchvision torchaudio lightning==2.3.3 tensorflow_datasets tensorflow matplotlib mediapy tensorflow_graphics einshape wandb
|
12 |
+
```
|
13 |
+
|
14 |
+
## LocoTrack Evaluation
|
15 |
+
|
16 |
+
### 1. Download Pre-trained Weights
|
17 |
+
|
18 |
+
To evaluate LocoTrack on the benchmarks, first download the pre-trained weights.
|
19 |
+
|
20 |
+
| Model | Pre-trained Weights |
|
21 |
+
|-------------|---------------------|
|
22 |
+
| LocoTrack-S | [Link](https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_small.ckpt) |
|
23 |
+
| LocoTrack-B | [Link](https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_base.ckpt) |
|
24 |
+
|
25 |
+
### 2. Adjust the Config File
|
26 |
+
|
27 |
+
In `config/default.ini` (or any other config file), add the path to the evaluation datasets to `[TRAINING]-val_dataset_path`. Additionally, adjust the model size for evaluation in `[MODEL]-model_kwargs-model_size`.
|
28 |
+
|
29 |
+
### 3. Run Evaluation
|
30 |
+
|
31 |
+
To evaluate the LocoTrack model, use the `experiment.py` script with the following command-line arguments:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
python experiment.py --config config/default.ini --mode eval_{dataset_to_eval_1}_..._{dataset_to_eval_N}[_q_first] --ckpt_path /path/to/checkpoint --save_path ./path_to_save_checkpoints/
|
35 |
+
```
|
36 |
+
|
37 |
+
- `--config`: Specifies the path to the configuration file. Default is `config/default.ini`.
|
38 |
+
- `--mode`: Specifies the mode to run the script. Use `eval` to perform evaluation. You can also include additional options for query first mode (`q_first`), and the name of the evaluation datasets. For example:
|
39 |
+
- Evaluation of the DAVIS dataset: `eval_davis`
|
40 |
+
- Evaluation of DAVIS and RoboTAP in query first mode: `eval_davis_robotap_q_first`
|
41 |
+
- `--ckpt_path`: Specifies the path to the checkpoint file. If not provided, the script will use the default checkpoint.
|
42 |
+
- `--save_path`: Specifies the path to save logs.
|
43 |
+
|
44 |
+
Replace `/path/to/checkpoint` with the actual path to your checkpoint file. This command will run the evaluation process and save the results in the specified `save_path`.
|
45 |
+
|
46 |
+
## LocoTrack Training
|
47 |
+
|
48 |
+
### Training Dataset Preparation
|
49 |
+
|
50 |
+
Download the panning-MOVi-E dataset used for training (approximately 273GB) from Huggingface using the following script. Git LFS should be installed to download the dataset. To install Git LFS, please refer to this [link](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage?platform=linux). Additionally, downloading instructions for the Huggingface dataset are available at this [link](https://huggingface.co/docs/hub/en/datasets-downloading).
|
51 |
+
|
52 |
+
```bash
|
53 |
+
git clone [email protected]:datasets/hamacojr/LocoTrack-panning-MOVi-E
|
54 |
+
```
|
55 |
+
|
56 |
+
### Training Script
|
57 |
+
|
58 |
+
Add the path to the downloaded panning-MOVi-E to the `[TRAINING]-kubric_dir` entry in `config/default.ini` (or any other config file). Optionally, for efficient training, change `[TRAINING]-precision` in the config file to `bf16-mixed` to use `bfloat16`. Then, run the training with the following script:
|
59 |
+
|
60 |
+
```bash
|
61 |
+
python experiment.py --config config/default.ini --mode train_davis --save_path ./path_to_save_checkpoints/
|
62 |
+
```
|
locotrack_pytorch/config/default.ini
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[TRAINING]
|
2 |
+
val_dataset_path = {"davis": "/home/seokjuc/sensei-fs-link/tapvid/tapvid_davis/tapvid_davis.pkl", "robotics": "", "kinetics": "", "robotap": ""}
|
3 |
+
kubric_dir = ./kubric
|
4 |
+
precision = 32
|
5 |
+
batch_size = 4
|
6 |
+
val_check_interval = 1000
|
7 |
+
log_every_n_steps = 5
|
8 |
+
gradient_clip_val = 1.0
|
9 |
+
max_steps = 300000
|
10 |
+
|
11 |
+
[MODEL]
|
12 |
+
model_kwargs = {"model_size": "base", "num_pips_iter": 4}
|
13 |
+
model_forward_kwargs = {"refinement_resolutions": ((256, 256),), "query_chunk_size": 256}
|
14 |
+
|
15 |
+
[LOSS]
|
16 |
+
loss_name = tapir_loss
|
17 |
+
loss_kwargs = {}
|
18 |
+
|
19 |
+
[OPTIMIZER]
|
20 |
+
optimizer_name = AdamW
|
21 |
+
optimizer_kwargs = {"lr": 1e-3, "weight_decay": 1e-3, "betas": (0.9, 0.95)}
|
22 |
+
|
23 |
+
[SCHEDULER]
|
24 |
+
scheduler_name = OneCycleLR
|
25 |
+
scheduler_kwargs = {"max_lr": 1e-3, "pct_start": 0.003, "total_steps": 300000}
|
locotrack_pytorch/data/evaluation_datasets.py
ADDED
@@ -0,0 +1,784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 DeepMind Technologies Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Evaluation dataset creation functions."""
|
17 |
+
|
18 |
+
import csv
|
19 |
+
import functools
|
20 |
+
import io
|
21 |
+
import os
|
22 |
+
from os import path
|
23 |
+
import pickle
|
24 |
+
import random
|
25 |
+
from typing import Iterable, Mapping, Optional, Tuple, Union
|
26 |
+
|
27 |
+
from absl import logging
|
28 |
+
|
29 |
+
import mediapy as media
|
30 |
+
import numpy as np
|
31 |
+
from PIL import Image
|
32 |
+
import scipy.io as sio
|
33 |
+
import tensorflow as tf
|
34 |
+
import tensorflow_datasets as tfds
|
35 |
+
|
36 |
+
from models.utils import convert_grid_coordinates
|
37 |
+
|
38 |
+
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
|
39 |
+
|
40 |
+
|
41 |
+
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
42 |
+
"""Resize a video to output_size."""
|
43 |
+
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
|
44 |
+
# such as a jitted jax.image.resize. It will make things faster.
|
45 |
+
return media.resize_video(video, output_size)
|
46 |
+
|
47 |
+
|
48 |
+
def compute_tapvid_metrics(
|
49 |
+
query_points: np.ndarray,
|
50 |
+
gt_occluded: np.ndarray,
|
51 |
+
gt_tracks: np.ndarray,
|
52 |
+
pred_occluded: np.ndarray,
|
53 |
+
pred_tracks: np.ndarray,
|
54 |
+
query_mode: str,
|
55 |
+
get_trackwise_metrics: bool = False,
|
56 |
+
) -> Mapping[str, np.ndarray]:
|
57 |
+
"""Computes TAP-Vid metrics (Jaccard, Pts.
|
58 |
+
|
59 |
+
Within Thresh, Occ.
|
60 |
+
|
61 |
+
Acc.)
|
62 |
+
|
63 |
+
See the TAP-Vid paper for details on the metric computation. All inputs are
|
64 |
+
given in raster coordinates. The first three arguments should be the direct
|
65 |
+
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
|
66 |
+
The paper metrics assume these are scaled relative to 256x256 images.
|
67 |
+
pred_occluded and pred_tracks are your algorithm's predictions.
|
68 |
+
|
69 |
+
This function takes a batch of inputs, and computes metrics separately for
|
70 |
+
each video. The metrics for the full benchmark are a simple mean of the
|
71 |
+
metrics across the full set of videos. These numbers are between 0 and 1,
|
72 |
+
but the paper multiplies them by 100 to ease reading.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
query_points: The query points, an in the format [t, y, x]. Its size is
|
76 |
+
[b, n, 3], where b is the batch size and n is the number of queries
|
77 |
+
gt_occluded: A boolean array of shape [b, n, t], where t is the number of
|
78 |
+
frames. True indicates that the point is occluded.
|
79 |
+
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is in the
|
80 |
+
format [x, y]
|
81 |
+
pred_occluded: A boolean array of predicted occlusions, in the same format
|
82 |
+
as gt_occluded.
|
83 |
+
pred_tracks: An array of track predictions from your algorithm, in the same
|
84 |
+
format as gt_tracks.
|
85 |
+
query_mode: Either 'first' or 'strided', depending on how queries are
|
86 |
+
sampled. If 'first', we assume the prior knowledge that all points
|
87 |
+
before the query point are occluded, and these are removed from the
|
88 |
+
evaluation.
|
89 |
+
get_trackwise_metrics: if True, the metrics will be computed for every
|
90 |
+
track (rather than every video, which is the default). This means
|
91 |
+
every output tensor will have an extra axis [batch, num_tracks] rather
|
92 |
+
than simply (batch).
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
A dict with the following keys:
|
96 |
+
|
97 |
+
occlusion_accuracy: Accuracy at predicting occlusion.
|
98 |
+
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
|
99 |
+
predicted to be within the given pixel threshold, ignoring occlusion
|
100 |
+
prediction.
|
101 |
+
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
|
102 |
+
threshold
|
103 |
+
average_pts_within_thresh: average across pts_within_{x}
|
104 |
+
average_jaccard: average across jaccard_{x}
|
105 |
+
"""
|
106 |
+
|
107 |
+
summing_axis = (2,) if get_trackwise_metrics else (1, 2)
|
108 |
+
|
109 |
+
metrics = {}
|
110 |
+
|
111 |
+
eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
|
112 |
+
if query_mode == 'first':
|
113 |
+
# evaluate frames after the query frame
|
114 |
+
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
|
115 |
+
elif query_mode == 'strided':
|
116 |
+
# evaluate all frames except the query frame
|
117 |
+
query_frame_to_eval_frames = 1 - eye
|
118 |
+
else:
|
119 |
+
raise ValueError('Unknown query mode ' + query_mode)
|
120 |
+
|
121 |
+
query_frame = query_points[..., 0]
|
122 |
+
query_frame = np.round(query_frame).astype(np.int32)
|
123 |
+
evaluation_points = query_frame_to_eval_frames[query_frame] > 0
|
124 |
+
|
125 |
+
# Occlusion accuracy is simply how often the predicted occlusion equals the
|
126 |
+
# ground truth.
|
127 |
+
occ_acc = np.sum(
|
128 |
+
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
129 |
+
axis=summing_axis,
|
130 |
+
) / np.sum(evaluation_points, axis=summing_axis)
|
131 |
+
metrics['occlusion_accuracy'] = occ_acc
|
132 |
+
|
133 |
+
# Next, convert the predictions and ground truth positions into pixel
|
134 |
+
# coordinates.
|
135 |
+
visible = np.logical_not(gt_occluded)
|
136 |
+
pred_visible = np.logical_not(pred_occluded)
|
137 |
+
all_frac_within = []
|
138 |
+
all_jaccard = []
|
139 |
+
for thresh in [1, 2, 4, 8, 16]:
|
140 |
+
# True positives are points that are within the threshold and where both
|
141 |
+
# the prediction and the ground truth are listed as visible.
|
142 |
+
within_dist = np.sum(
|
143 |
+
np.square(pred_tracks - gt_tracks),
|
144 |
+
axis=-1,
|
145 |
+
) < np.square(thresh)
|
146 |
+
is_correct = np.logical_and(within_dist, visible)
|
147 |
+
|
148 |
+
# Compute the frac_within_threshold, which is the fraction of points
|
149 |
+
# within the threshold among points that are visible in the ground truth,
|
150 |
+
# ignoring whether they're predicted to be visible.
|
151 |
+
count_correct = np.sum(
|
152 |
+
is_correct & evaluation_points,
|
153 |
+
axis=summing_axis,
|
154 |
+
)
|
155 |
+
count_visible_points = np.sum(
|
156 |
+
visible & evaluation_points, axis=summing_axis
|
157 |
+
)
|
158 |
+
frac_correct = count_correct / count_visible_points
|
159 |
+
metrics['pts_within_' + str(thresh)] = frac_correct
|
160 |
+
all_frac_within.append(frac_correct)
|
161 |
+
|
162 |
+
true_positives = np.sum(
|
163 |
+
is_correct & pred_visible & evaluation_points, axis=summing_axis
|
164 |
+
)
|
165 |
+
|
166 |
+
# The denominator of the jaccard metric is the true positives plus
|
167 |
+
# false positives plus false negatives. However, note that true positives
|
168 |
+
# plus false negatives is simply the number of points in the ground truth
|
169 |
+
# which is easier to compute than trying to compute all three quantities.
|
170 |
+
# Thus we just add the number of points in the ground truth to the number
|
171 |
+
# of false positives.
|
172 |
+
#
|
173 |
+
# False positives are simply points that are predicted to be visible,
|
174 |
+
# but the ground truth is not visible or too far from the prediction.
|
175 |
+
gt_positives = np.sum(visible & evaluation_points, axis=summing_axis)
|
176 |
+
false_positives = (~visible) & pred_visible
|
177 |
+
false_positives = false_positives | ((~within_dist) & pred_visible)
|
178 |
+
false_positives = np.sum(
|
179 |
+
false_positives & evaluation_points, axis=summing_axis
|
180 |
+
)
|
181 |
+
jaccard = true_positives / (gt_positives + false_positives)
|
182 |
+
metrics['jaccard_' + str(thresh)] = jaccard
|
183 |
+
all_jaccard.append(jaccard)
|
184 |
+
metrics['average_jaccard'] = np.mean(
|
185 |
+
np.stack(all_jaccard, axis=1),
|
186 |
+
axis=1,
|
187 |
+
)
|
188 |
+
metrics['average_pts_within_thresh'] = np.mean(
|
189 |
+
np.stack(all_frac_within, axis=1),
|
190 |
+
axis=1,
|
191 |
+
)
|
192 |
+
return metrics
|
193 |
+
|
194 |
+
|
195 |
+
def latex_table(mean_scalars: Mapping[str, float]) -> str:
|
196 |
+
"""Generate a latex table for displaying TAP-Vid and PCK metrics."""
|
197 |
+
if 'average_jaccard' in mean_scalars:
|
198 |
+
latex_fields = [
|
199 |
+
'average_jaccard',
|
200 |
+
'average_pts_within_thresh',
|
201 |
+
'occlusion_accuracy',
|
202 |
+
'jaccard_1',
|
203 |
+
'jaccard_2',
|
204 |
+
'jaccard_4',
|
205 |
+
'jaccard_8',
|
206 |
+
'jaccard_16',
|
207 |
+
'pts_within_1',
|
208 |
+
'pts_within_2',
|
209 |
+
'pts_within_4',
|
210 |
+
'pts_within_8',
|
211 |
+
'pts_within_16',
|
212 |
+
]
|
213 |
+
header = (
|
214 |
+
'AJ & $<\\delta^{x}_{avg}$ & OA & Jac. $\\delta^{0}$ & '
|
215 |
+
+ 'Jac. $\\delta^{1}$ & Jac. $\\delta^{2}$ & '
|
216 |
+
+ 'Jac. $\\delta^{3}$ & Jac. $\\delta^{4}$ & $<\\delta^{0}$ & '
|
217 |
+
+ '$<\\delta^{1}$ & $<\\delta^{2}$ & $<\\delta^{3}$ & '
|
218 |
+
+ '$<\\delta^{4}$'
|
219 |
+
)
|
220 |
+
else:
|
221 |
+
latex_fields = ['[email protected]', '[email protected]', '[email protected]', '[email protected]', '[email protected]']
|
222 |
+
header = ' & '.join(latex_fields)
|
223 |
+
|
224 |
+
body = ' & '.join(
|
225 |
+
[f'{float(np.array(mean_scalars[x]*100)):.3}' for x in latex_fields]
|
226 |
+
)
|
227 |
+
return '\n'.join([header, body])
|
228 |
+
|
229 |
+
|
230 |
+
def sample_queries_strided(
|
231 |
+
target_occluded: np.ndarray,
|
232 |
+
target_points: np.ndarray,
|
233 |
+
frames: np.ndarray,
|
234 |
+
query_stride: int = 5,
|
235 |
+
) -> Mapping[str, np.ndarray]:
|
236 |
+
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
237 |
+
|
238 |
+
Given a set of frames and tracks with no query points, sample queries
|
239 |
+
strided every query_stride frames, ignoring points that are not visible
|
240 |
+
at the selected frames.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
244 |
+
where True indicates occluded.
|
245 |
+
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
246 |
+
is [x,y] scaled between 0 and 1.
|
247 |
+
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
248 |
+
-1 and 1.
|
249 |
+
query_stride: When sampling query points, search for un-occluded points
|
250 |
+
every query_stride frames and convert each one into a query.
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
A dict with the keys:
|
254 |
+
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
255 |
+
has floats scaled to the range [-1, 1].
|
256 |
+
query_points: Query points of shape [1, n_queries, 3] where
|
257 |
+
each point is [t, y, x] scaled to the range [-1, 1].
|
258 |
+
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
259 |
+
each point is [x, y] scaled to the range [-1, 1].
|
260 |
+
trackgroup: Index of the original track that each query point was
|
261 |
+
sampled from. This is useful for visualization.
|
262 |
+
"""
|
263 |
+
tracks = []
|
264 |
+
occs = []
|
265 |
+
queries = []
|
266 |
+
trackgroups = []
|
267 |
+
total = 0
|
268 |
+
trackgroup = np.arange(target_occluded.shape[0])
|
269 |
+
for i in range(0, target_occluded.shape[1], query_stride):
|
270 |
+
mask = target_occluded[:, i] == 0
|
271 |
+
query = np.stack(
|
272 |
+
[
|
273 |
+
i * np.ones(target_occluded.shape[0:1]),
|
274 |
+
target_points[:, i, 1],
|
275 |
+
target_points[:, i, 0],
|
276 |
+
],
|
277 |
+
axis=-1,
|
278 |
+
)
|
279 |
+
queries.append(query[mask])
|
280 |
+
tracks.append(target_points[mask])
|
281 |
+
occs.append(target_occluded[mask])
|
282 |
+
trackgroups.append(trackgroup[mask])
|
283 |
+
total += np.array(np.sum(target_occluded[:, i] == 0))
|
284 |
+
|
285 |
+
return {
|
286 |
+
'video': frames[np.newaxis, ...],
|
287 |
+
'query_points': np.concatenate(queries, axis=0)[np.newaxis, ...],
|
288 |
+
'target_points': np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
289 |
+
'occluded': np.concatenate(occs, axis=0)[np.newaxis, ...],
|
290 |
+
'trackgroup': np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
291 |
+
}
|
292 |
+
|
293 |
+
|
294 |
+
def sample_queries_first(
|
295 |
+
target_occluded: np.ndarray,
|
296 |
+
target_points: np.ndarray,
|
297 |
+
frames: np.ndarray,
|
298 |
+
) -> Mapping[str, np.ndarray]:
|
299 |
+
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
300 |
+
|
301 |
+
Given a set of frames and tracks with no query points, use the first
|
302 |
+
visible point in each track as the query.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
306 |
+
where True indicates occluded.
|
307 |
+
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
308 |
+
is [x,y] scaled between 0 and 1.
|
309 |
+
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
310 |
+
-1 and 1.
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
A dict with the keys:
|
314 |
+
video: Video tensor of shape [1, n_frames, height, width, 3]
|
315 |
+
query_points: Query points of shape [1, n_queries, 3] where
|
316 |
+
each point is [t, y, x] scaled to the range [-1, 1]
|
317 |
+
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
318 |
+
each point is [x, y] scaled to the range [-1, 1]
|
319 |
+
"""
|
320 |
+
|
321 |
+
valid = np.sum(~target_occluded, axis=1) > 0
|
322 |
+
target_points = target_points[valid, :]
|
323 |
+
target_occluded = target_occluded[valid, :]
|
324 |
+
|
325 |
+
query_points = []
|
326 |
+
for i in range(target_points.shape[0]):
|
327 |
+
index = np.where(target_occluded[i] == 0)[0][0]
|
328 |
+
x, y = target_points[i, index, 0], target_points[i, index, 1]
|
329 |
+
query_points.append(np.array([index, y, x])) # [t, y, x]
|
330 |
+
query_points = np.stack(query_points, axis=0)
|
331 |
+
|
332 |
+
return {
|
333 |
+
'video': frames[np.newaxis, ...],
|
334 |
+
'query_points': query_points[np.newaxis, ...],
|
335 |
+
'target_points': target_points[np.newaxis, ...],
|
336 |
+
'occluded': target_occluded[np.newaxis, ...],
|
337 |
+
}
|
338 |
+
|
339 |
+
|
340 |
+
def create_jhmdb_dataset(
|
341 |
+
jhmdb_path: str, resolution: Optional[Tuple[int, int]] = (256, 256)
|
342 |
+
) -> Iterable[DatasetElement]:
|
343 |
+
"""JHMDB dataset, including fields required for PCK evaluation."""
|
344 |
+
videos = []
|
345 |
+
for file in tf.io.gfile.listdir(path.join(gt_dir, 'splits')):
|
346 |
+
# JHMDB file containing the first split, which is standard for this type of
|
347 |
+
# evaluation.
|
348 |
+
if not file.endswith('split1.txt'):
|
349 |
+
continue
|
350 |
+
|
351 |
+
video_folder = '_'.join(file.split('_')[:-2])
|
352 |
+
for video in tf.io.gfile.GFile(path.join(gt_dir, 'splits', file), 'r'):
|
353 |
+
video, traintest = video.split()
|
354 |
+
video, _ = video.split('.')
|
355 |
+
|
356 |
+
traintest = int(traintest)
|
357 |
+
video_path = path.join(video_folder, video)
|
358 |
+
|
359 |
+
if traintest == 2:
|
360 |
+
videos.append(video_path)
|
361 |
+
|
362 |
+
if not videos:
|
363 |
+
raise ValueError('No JHMDB videos found in directory ' + str(jhmdb_path))
|
364 |
+
|
365 |
+
# Shuffle so numbers converge faster.
|
366 |
+
random.shuffle(videos)
|
367 |
+
|
368 |
+
for video in videos:
|
369 |
+
logging.info(video)
|
370 |
+
joints = path.join(gt_dir, 'joint_positions', video, 'joint_positions.mat')
|
371 |
+
|
372 |
+
if not tf.io.gfile.exists(joints):
|
373 |
+
logging.info('skip %s', video)
|
374 |
+
continue
|
375 |
+
|
376 |
+
gt_pose = sio.loadmat(tf.io.gfile.GFile(joints, 'rb'))['pos_img']
|
377 |
+
gt_pose = np.transpose(gt_pose, [1, 2, 0])
|
378 |
+
frames = path.join(gt_dir, 'Rename_Images', video, '*.png')
|
379 |
+
framefil = tf.io.gfile.glob(frames)
|
380 |
+
framefil.sort()
|
381 |
+
|
382 |
+
def read_frame(f):
|
383 |
+
im = Image.open(tf.io.gfile.GFile(f, 'rb'))
|
384 |
+
im = im.convert('RGB')
|
385 |
+
im_data = np.array(im.getdata(), np.uint8)
|
386 |
+
return im_data.reshape([im.size[1], im.size[0], 3])
|
387 |
+
|
388 |
+
frames = [read_frame(x) for x in framefil]
|
389 |
+
frames = np.stack(frames)
|
390 |
+
height = frames.shape[1]
|
391 |
+
width = frames.shape[2]
|
392 |
+
invalid_x = np.logical_or(
|
393 |
+
gt_pose[:, 0:1, 0] < 0,
|
394 |
+
gt_pose[:, 0:1, 0] >= width,
|
395 |
+
)
|
396 |
+
invalid_y = np.logical_or(
|
397 |
+
gt_pose[:, 0:1, 1] < 0,
|
398 |
+
gt_pose[:, 0:1, 1] >= height,
|
399 |
+
)
|
400 |
+
invalid = np.logical_or(invalid_x, invalid_y)
|
401 |
+
invalid = np.tile(invalid, [1, gt_pose.shape[1]])
|
402 |
+
invalid = invalid[:, :, np.newaxis].astype(np.float32)
|
403 |
+
gt_pose_orig = gt_pose
|
404 |
+
|
405 |
+
if resolution is not None and resolution != frames.shape[1:3]:
|
406 |
+
frames = resize_video(frames, resolution)
|
407 |
+
frames = frames / (255.0 / 2.0) - 1.0
|
408 |
+
queries = gt_pose[:, 0]
|
409 |
+
queries = np.concatenate(
|
410 |
+
[queries[..., 0:1] * 0, queries[..., ::-1]],
|
411 |
+
axis=-1,
|
412 |
+
)
|
413 |
+
gt_pose = convert_grid_coordinates(
|
414 |
+
gt_pose,
|
415 |
+
np.array([width, height]),
|
416 |
+
np.array([frames.shape[2], frames.shape[1]]),
|
417 |
+
)
|
418 |
+
# Set invalid poses to -1 (outside the frame)
|
419 |
+
gt_pose = (1.0 - invalid) * gt_pose + invalid * (-1.0)
|
420 |
+
|
421 |
+
if gt_pose.shape[1] < frames.shape[0]:
|
422 |
+
# Some videos have pose sequences that are shorter than the frame
|
423 |
+
# sequence (usually because the person disappears). In this case,
|
424 |
+
# truncate the video.
|
425 |
+
logging.warning('short video!!')
|
426 |
+
frames = frames[: gt_pose.shape[1]]
|
427 |
+
|
428 |
+
converted = {
|
429 |
+
'video': frames[np.newaxis, ...],
|
430 |
+
'query_points': queries[np.newaxis, ...],
|
431 |
+
'target_points': gt_pose[np.newaxis, ...],
|
432 |
+
'gt_pose': gt_pose[np.newaxis, ...],
|
433 |
+
'gt_pose_orig': gt_pose_orig[np.newaxis, ...],
|
434 |
+
'occluded': gt_pose[np.newaxis, ..., 0] * 0,
|
435 |
+
'fname': video,
|
436 |
+
'im_size': np.array([height, width]),
|
437 |
+
}
|
438 |
+
yield {'jhmdb': converted}
|
439 |
+
|
440 |
+
|
441 |
+
def create_kubric_eval_train_dataset(
|
442 |
+
mode: str,
|
443 |
+
train_size: Tuple[int, int] = (256, 256),
|
444 |
+
max_dataset_size: int = 100,
|
445 |
+
) -> Iterable[DatasetElement]:
|
446 |
+
"""Dataset for evaluating performance on Kubric training data."""
|
447 |
+
|
448 |
+
# Lazy import kubric because requirements_inference doesn't include it.
|
449 |
+
from kubric.challenges.point_tracking import dataset
|
450 |
+
res = dataset.create_point_tracking_dataset(
|
451 |
+
split='train',
|
452 |
+
train_size=train_size,
|
453 |
+
batch_dims=[1],
|
454 |
+
shuffle_buffer_size=None,
|
455 |
+
repeat=False,
|
456 |
+
vflip='vflip' in mode,
|
457 |
+
random_crop=False,
|
458 |
+
)
|
459 |
+
np_ds = tfds.as_numpy(res)
|
460 |
+
|
461 |
+
num_returned = 0
|
462 |
+
for data in np_ds:
|
463 |
+
if num_returned >= max_dataset_size:
|
464 |
+
break
|
465 |
+
num_returned += 1
|
466 |
+
yield {'kubric': data}
|
467 |
+
|
468 |
+
|
469 |
+
def create_kubric_eval_dataset(
|
470 |
+
mode: str, train_size: Tuple[int, int] = (256, 256)
|
471 |
+
) -> Iterable[DatasetElement]:
|
472 |
+
"""Dataset for evaluating performance on Kubric val data."""
|
473 |
+
# Lazy import kubric because requirements_inference doesn't include it.
|
474 |
+
from kubric.challenges.point_tracking import dataset
|
475 |
+
res = dataset.create_point_tracking_dataset(
|
476 |
+
split='validation',
|
477 |
+
train_size=train_size,
|
478 |
+
batch_dims=[1],
|
479 |
+
shuffle_buffer_size=None,
|
480 |
+
repeat=False,
|
481 |
+
vflip='vflip' in mode,
|
482 |
+
random_crop=False,
|
483 |
+
)
|
484 |
+
np_ds = tfds.as_numpy(res)
|
485 |
+
|
486 |
+
for data in np_ds:
|
487 |
+
yield {'kubric': data}
|
488 |
+
|
489 |
+
|
490 |
+
def create_davis_dataset(
|
491 |
+
davis_points_path: str,
|
492 |
+
query_mode: str = 'strided',
|
493 |
+
full_resolution=False,
|
494 |
+
resolution: Optional[Tuple[int, int]] = (256, 256),
|
495 |
+
) -> Iterable[DatasetElement]:
|
496 |
+
"""Dataset for evaluating performance on DAVIS data."""
|
497 |
+
pickle_path = davis_points_path
|
498 |
+
|
499 |
+
with tf.io.gfile.GFile(pickle_path, 'rb') as f:
|
500 |
+
davis_points_dataset = pickle.load(f)
|
501 |
+
|
502 |
+
if full_resolution:
|
503 |
+
ds, _ = tfds.load(
|
504 |
+
'davis/full_resolution', split='validation', with_info=True
|
505 |
+
)
|
506 |
+
to_iterate = tfds.as_numpy(ds)
|
507 |
+
else:
|
508 |
+
to_iterate = davis_points_dataset.keys()
|
509 |
+
|
510 |
+
for tmp in to_iterate:
|
511 |
+
if full_resolution:
|
512 |
+
frames = tmp['video']['frames']
|
513 |
+
video_name = tmp['metadata']['video_name'].decode()
|
514 |
+
else:
|
515 |
+
video_name = tmp
|
516 |
+
frames = davis_points_dataset[video_name]['video']
|
517 |
+
if resolution is not None and resolution != frames.shape[1:3]:
|
518 |
+
frames = resize_video(frames, resolution)
|
519 |
+
|
520 |
+
frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
|
521 |
+
target_points = davis_points_dataset[video_name]['points']
|
522 |
+
target_occ = davis_points_dataset[video_name]['occluded']
|
523 |
+
target_points = target_points * np.array([frames.shape[2], frames.shape[1]])
|
524 |
+
|
525 |
+
if query_mode == 'strided':
|
526 |
+
converted = sample_queries_strided(target_occ, target_points, frames)
|
527 |
+
elif query_mode == 'first':
|
528 |
+
converted = sample_queries_first(target_occ, target_points, frames)
|
529 |
+
else:
|
530 |
+
raise ValueError(f'Unknown query mode {query_mode}.')
|
531 |
+
|
532 |
+
yield {'davis': converted}
|
533 |
+
|
534 |
+
|
535 |
+
def create_rgb_stacking_dataset(
|
536 |
+
robotics_points_path: str,
|
537 |
+
query_mode: str = 'strided',
|
538 |
+
resolution: Optional[Tuple[int, int]] = (256, 256),
|
539 |
+
) -> Iterable[DatasetElement]:
|
540 |
+
"""Dataset for evaluating performance on robotics data."""
|
541 |
+
pickle_path = robotics_points_path
|
542 |
+
|
543 |
+
with tf.io.gfile.GFile(pickle_path, 'rb') as f:
|
544 |
+
robotics_points_dataset = pickle.load(f)
|
545 |
+
|
546 |
+
for example in robotics_points_dataset:
|
547 |
+
frames = example['video']
|
548 |
+
if resolution is not None and resolution != frames.shape[1:3]:
|
549 |
+
frames = resize_video(frames, resolution)
|
550 |
+
frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
|
551 |
+
target_points = example['points']
|
552 |
+
target_occ = example['occluded']
|
553 |
+
target_points = target_points * np.array([frames.shape[2], frames.shape[1]])
|
554 |
+
|
555 |
+
if query_mode == 'strided':
|
556 |
+
converted = sample_queries_strided(target_occ, target_points, frames)
|
557 |
+
elif query_mode == 'first':
|
558 |
+
converted = sample_queries_first(target_occ, target_points, frames)
|
559 |
+
else:
|
560 |
+
raise ValueError(f'Unknown query mode {query_mode}.')
|
561 |
+
|
562 |
+
yield {'robotics': converted}
|
563 |
+
|
564 |
+
|
565 |
+
def create_kinetics_dataset(
|
566 |
+
kinetics_path: str, query_mode: str = 'strided',
|
567 |
+
resolution: Optional[Tuple[int, int]] = (256, 256),
|
568 |
+
) -> Iterable[DatasetElement]:
|
569 |
+
"""Dataset for evaluating performance on Kinetics point tracking."""
|
570 |
+
|
571 |
+
all_paths = tf.io.gfile.glob(path.join(kinetics_path, '*_of_0010.pkl'))
|
572 |
+
for pickle_path in all_paths:
|
573 |
+
with open(pickle_path, 'rb') as f:
|
574 |
+
data = pickle.load(f)
|
575 |
+
if isinstance(data, dict):
|
576 |
+
data = list(data.values())
|
577 |
+
|
578 |
+
# idx = random.randint(0, len(data) - 1)
|
579 |
+
for idx in range(len(data)):
|
580 |
+
example = data[idx]
|
581 |
+
|
582 |
+
frames = example['video']
|
583 |
+
|
584 |
+
if isinstance(frames[0], bytes):
|
585 |
+
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
586 |
+
def decode(frame):
|
587 |
+
byteio = io.BytesIO(frame)
|
588 |
+
img = Image.open(byteio)
|
589 |
+
return np.array(img)
|
590 |
+
|
591 |
+
frames = np.array([decode(frame) for frame in frames])
|
592 |
+
|
593 |
+
if resolution is not None and resolution != frames.shape[1:3]:
|
594 |
+
frames = resize_video(frames, resolution)
|
595 |
+
|
596 |
+
frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
|
597 |
+
target_points = example['points']
|
598 |
+
target_occ = example['occluded']
|
599 |
+
target_points *= np.array([frames.shape[2], frames.shape[1]])
|
600 |
+
|
601 |
+
if query_mode == 'strided':
|
602 |
+
converted = sample_queries_strided(target_occ, target_points, frames)
|
603 |
+
elif query_mode == 'first':
|
604 |
+
converted = sample_queries_first(target_occ, target_points, frames)
|
605 |
+
else:
|
606 |
+
raise ValueError(f'Unknown query mode {query_mode}.')
|
607 |
+
|
608 |
+
yield {'kinetics': converted}
|
609 |
+
|
610 |
+
|
611 |
+
def create_robotap_dataset(
|
612 |
+
robotics_points_path: str,
|
613 |
+
query_mode: str = 'strided',
|
614 |
+
resolution: Optional[Tuple[int, int]] = (256, 256),
|
615 |
+
) -> Iterable[DatasetElement]:
|
616 |
+
"""Dataset for evaluating performance on robotics data."""
|
617 |
+
pickle_path = robotics_points_path
|
618 |
+
|
619 |
+
# with tf.io.gfile.GFile(pickle_path, 'rb') as f:
|
620 |
+
# robotics_points_dataset = pickle.load(f)
|
621 |
+
robotics_points_dataset = []
|
622 |
+
all_paths = tf.io.gfile.glob(path.join(robotics_points_path, '*.pkl'))
|
623 |
+
for pickle_path in all_paths:
|
624 |
+
with open(pickle_path, 'rb') as f:
|
625 |
+
data = pickle.load(f)
|
626 |
+
robotics_points_dataset.extend(data.values())
|
627 |
+
|
628 |
+
for example in robotics_points_dataset:
|
629 |
+
frames = example['video']
|
630 |
+
if resolution is not None and resolution != frames.shape[1:3]:
|
631 |
+
frames = resize_video(frames, resolution)
|
632 |
+
frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
|
633 |
+
target_points = example['points']
|
634 |
+
target_occ = example['occluded']
|
635 |
+
target_points = target_points * np.array([frames.shape[2], frames.shape[1]])
|
636 |
+
|
637 |
+
if query_mode == 'strided':
|
638 |
+
converted = sample_queries_strided(target_occ, target_points, frames)
|
639 |
+
elif query_mode == 'first':
|
640 |
+
converted = sample_queries_first(target_occ, target_points, frames)
|
641 |
+
else:
|
642 |
+
raise ValueError(f'Unknown query mode {query_mode}.')
|
643 |
+
|
644 |
+
yield {'robotap': converted}
|
645 |
+
|
646 |
+
|
647 |
+
def create_csv_dataset(
|
648 |
+
dataset_name: str,
|
649 |
+
csv_path: str,
|
650 |
+
video_base_path: str,
|
651 |
+
query_mode: str = 'strided',
|
652 |
+
resolution: Optional[Tuple[int, int]] = (256, 256),
|
653 |
+
max_video_frames: Optional[int] = 1000,
|
654 |
+
) -> Iterable[DatasetElement]:
|
655 |
+
"""Create an evaluation iterator out of human annotations and videos.
|
656 |
+
|
657 |
+
Args:
|
658 |
+
dataset_name: Name to the dataset.
|
659 |
+
csv_path: Path to annotations csv.
|
660 |
+
video_base_path: Path to annotated videos.
|
661 |
+
query_mode: sample query points from first frame or strided.
|
662 |
+
resolution: The video resolution in (height, width).
|
663 |
+
max_video_frames: Max length of annotated video.
|
664 |
+
|
665 |
+
Yields:
|
666 |
+
Samples for evaluation.
|
667 |
+
"""
|
668 |
+
point_tracks_all = dict()
|
669 |
+
with tf.io.gfile.GFile(csv_path, 'r') as f:
|
670 |
+
reader = csv.reader(f, delimiter=',')
|
671 |
+
for row in reader:
|
672 |
+
video_id = row[0]
|
673 |
+
point_tracks = np.array(row[1:]).reshape(-1, 3)
|
674 |
+
if video_id in point_tracks_all:
|
675 |
+
point_tracks_all[video_id].append(point_tracks)
|
676 |
+
else:
|
677 |
+
point_tracks_all[video_id] = [point_tracks]
|
678 |
+
|
679 |
+
for video_id in point_tracks_all:
|
680 |
+
if video_id.endswith('.mp4'):
|
681 |
+
video_path = path.join(video_base_path, video_id)
|
682 |
+
else:
|
683 |
+
video_path = path.join(video_base_path, video_id + '.mp4')
|
684 |
+
frames = media.read_video(video_path)
|
685 |
+
if resolution is not None and resolution != frames.shape[1:3]:
|
686 |
+
frames = media.resize_video(frames, resolution)
|
687 |
+
frames = frames.astype(np.float32) / 255.0 * 2.0 - 1.0
|
688 |
+
|
689 |
+
point_tracks = np.stack(point_tracks_all[video_id], axis=0)
|
690 |
+
point_tracks = point_tracks.astype(np.float32)
|
691 |
+
if frames.shape[0] < point_tracks.shape[1]:
|
692 |
+
logging.info('Warning: short video!')
|
693 |
+
point_tracks = point_tracks[:, : frames.shape[0]]
|
694 |
+
point_tracks, occluded = point_tracks[..., 0:2], point_tracks[..., 2]
|
695 |
+
occluded = occluded > 0
|
696 |
+
target_points = point_tracks * np.array([frames.shape[2], frames.shape[1]])
|
697 |
+
|
698 |
+
num_splits = int(np.ceil(frames.shape[0] / max_video_frames))
|
699 |
+
if num_splits > 1:
|
700 |
+
print(f'Going to split the video {video_id} into {num_splits}')
|
701 |
+
for i in range(num_splits):
|
702 |
+
start_index = i * frames.shape[0] // num_splits
|
703 |
+
end_index = (i + 1) * frames.shape[0] // num_splits
|
704 |
+
sub_occluded = occluded[:, start_index:end_index]
|
705 |
+
sub_target_points = target_points[:, start_index:end_index]
|
706 |
+
sub_frames = frames[start_index:end_index]
|
707 |
+
|
708 |
+
if query_mode == 'strided':
|
709 |
+
converted = sample_queries_strided(
|
710 |
+
sub_occluded, sub_target_points, sub_frames
|
711 |
+
)
|
712 |
+
elif query_mode == 'first':
|
713 |
+
converted = sample_queries_first(
|
714 |
+
sub_occluded, sub_target_points, sub_frames
|
715 |
+
)
|
716 |
+
else:
|
717 |
+
raise ValueError(f'Unknown query mode {query_mode}.')
|
718 |
+
|
719 |
+
yield {dataset_name: converted}
|
720 |
+
|
721 |
+
|
722 |
+
import torch
|
723 |
+
from torch.utils.data import Dataset
|
724 |
+
|
725 |
+
|
726 |
+
class CustomDataset(Dataset):
|
727 |
+
def __init__(self, data_generator: Iterable[DatasetElement], key: str):
|
728 |
+
self.data = list(data_generator)
|
729 |
+
self.key = key
|
730 |
+
|
731 |
+
def __len__(self):
|
732 |
+
return len(self.data)
|
733 |
+
|
734 |
+
def __getitem__(self, idx):
|
735 |
+
data = self.data[idx][self.key]
|
736 |
+
data = {k: torch.tensor(v)[0] if isinstance(v, np.ndarray) else v for k, v in data.items()}
|
737 |
+
# Convert double to float
|
738 |
+
data = {k: v.float() if v.dtype == torch.float64 else v for k, v in data.items()}
|
739 |
+
return data
|
740 |
+
|
741 |
+
|
742 |
+
def get_eval_dataset(mode, path, resolution=(256, 256)):
|
743 |
+
query_mode = 'first' if 'q_first' in mode else 'strided'
|
744 |
+
datasets = {}
|
745 |
+
if 'jhmdb' in mode:
|
746 |
+
key = 'jhmdb'
|
747 |
+
dataset = create_jhmdb_dataset(path[key], resolution)
|
748 |
+
datasets[key] = CustomDataset(dataset, key)
|
749 |
+
if 'davis' in mode:
|
750 |
+
key = 'davis'
|
751 |
+
dataset = create_davis_dataset(path[key], query_mode, False, resolution=resolution)
|
752 |
+
datasets[key] = CustomDataset(dataset, key)
|
753 |
+
if 'robotics' in mode:
|
754 |
+
key = 'robotics'
|
755 |
+
dataset = create_rgb_stacking_dataset(path[key], query_mode, resolution)
|
756 |
+
datasets[key] = CustomDataset(dataset, key)
|
757 |
+
if 'kinetics' in mode:
|
758 |
+
key = 'kinetics'
|
759 |
+
dataset = create_kinetics_dataset(path[key], query_mode, resolution)
|
760 |
+
datasets[key] = CustomDataset(dataset, key)
|
761 |
+
if 'robotap' in mode:
|
762 |
+
key = 'robotap'
|
763 |
+
dataset = create_robotap_dataset(path[key], query_mode, resolution)
|
764 |
+
datasets[key] = CustomDataset(dataset, key)
|
765 |
+
|
766 |
+
if len(datasets) == 0:
|
767 |
+
raise ValueError(f'No dataset found for mode {mode}.')
|
768 |
+
|
769 |
+
return datasets
|
770 |
+
|
771 |
+
|
772 |
+
if __name__ == '__main__':
|
773 |
+
# Disable all GPUS
|
774 |
+
tf.config.set_visible_devices([], 'GPU')
|
775 |
+
visible_devices = tf.config.get_visible_devices()
|
776 |
+
for device in visible_devices:
|
777 |
+
assert device.device_type != 'GPU'
|
778 |
+
|
779 |
+
dataset_name = 'davis'
|
780 |
+
dataset_path = '/media/data2/PointTracking/tapvid/tapnet_dataset/tapvid_davis/tapvid_davis.pkl'
|
781 |
+
|
782 |
+
dataset = get_eval_dataset(dataset_name, dataset_path, 'strided', (256, 256))
|
783 |
+
breakpoint()
|
784 |
+
pass
|
locotrack_pytorch/data/kubric_data.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Mapping
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import functools
|
7 |
+
import tensorflow_datasets as tfds
|
8 |
+
import tensorflow as tf
|
9 |
+
import torch.distributed
|
10 |
+
from kubric.challenges.point_tracking.dataset import add_tracks
|
11 |
+
|
12 |
+
|
13 |
+
# Disable all GPUS
|
14 |
+
tf.config.set_visible_devices([], 'GPU')
|
15 |
+
visible_devices = tf.config.get_visible_devices()
|
16 |
+
for device in visible_devices:
|
17 |
+
assert device.device_type != 'GPU'
|
18 |
+
|
19 |
+
|
20 |
+
def default_color_augmentation_fn(
|
21 |
+
inputs: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
|
22 |
+
"""Standard color augmentation for videos.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
inputs: A DatasetElement containing the item 'video' which will have
|
26 |
+
augmentations applied to it.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
A DatasetElement with all the same data as the original, except that
|
30 |
+
the video has augmentations applied.
|
31 |
+
"""
|
32 |
+
zero_centering_image = True
|
33 |
+
prob_color_augment = 0.8
|
34 |
+
prob_color_drop = 0.2
|
35 |
+
|
36 |
+
frames = inputs['video']
|
37 |
+
if frames.dtype != tf.float32:
|
38 |
+
raise ValueError('`frames` should be in float32.')
|
39 |
+
|
40 |
+
def color_augment(video: tf.Tensor) -> tf.Tensor:
|
41 |
+
"""Do standard color augmentations."""
|
42 |
+
# Note the same augmentation will be applied to all frames of the video.
|
43 |
+
if zero_centering_image:
|
44 |
+
video = 0.5 * (video + 1.0)
|
45 |
+
video = tf.image.random_brightness(video, max_delta=32. / 255.)
|
46 |
+
video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
|
47 |
+
video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
|
48 |
+
video = tf.image.random_hue(video, max_delta=0.2)
|
49 |
+
video = tf.clip_by_value(video, 0.0, 1.0)
|
50 |
+
if zero_centering_image:
|
51 |
+
video = 2 * (video-0.5)
|
52 |
+
return video
|
53 |
+
|
54 |
+
def color_drop(video: tf.Tensor) -> tf.Tensor:
|
55 |
+
video = tf.image.rgb_to_grayscale(video)
|
56 |
+
video = tf.tile(video, [1, 1, 1, 1, 3])
|
57 |
+
return video
|
58 |
+
|
59 |
+
# Eventually applies color augmentation.
|
60 |
+
coin_toss_color_augment = tf.random.uniform(
|
61 |
+
[], minval=0, maxval=1, dtype=tf.float32)
|
62 |
+
frames = tf.cond(
|
63 |
+
pred=tf.less(coin_toss_color_augment,
|
64 |
+
tf.cast(prob_color_augment, tf.float32)),
|
65 |
+
true_fn=lambda: color_augment(frames),
|
66 |
+
false_fn=lambda: frames)
|
67 |
+
|
68 |
+
# Eventually applies color drop.
|
69 |
+
coin_toss_color_drop = tf.random.uniform(
|
70 |
+
[], minval=0, maxval=1, dtype=tf.float32)
|
71 |
+
frames = tf.cond(
|
72 |
+
pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)),
|
73 |
+
true_fn=lambda: color_drop(frames),
|
74 |
+
false_fn=lambda: frames)
|
75 |
+
result = {**inputs}
|
76 |
+
result['video'] = frames
|
77 |
+
|
78 |
+
return result
|
79 |
+
|
80 |
+
|
81 |
+
def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
|
82 |
+
return ds.map(
|
83 |
+
default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)
|
84 |
+
|
85 |
+
|
86 |
+
def create_point_tracking_dataset(
|
87 |
+
data_dir=None,
|
88 |
+
color_augmentation=True,
|
89 |
+
train_size=(256, 256),
|
90 |
+
shuffle_buffer_size=256,
|
91 |
+
split='train',
|
92 |
+
# batch_dims=tuple(),
|
93 |
+
batch_size=1,
|
94 |
+
repeat=True,
|
95 |
+
vflip=False,
|
96 |
+
random_crop=True,
|
97 |
+
tracks_to_sample=256,
|
98 |
+
sampling_stride=4,
|
99 |
+
max_seg_id=40,
|
100 |
+
max_sampled_frac=0.1,
|
101 |
+
num_parallel_point_extraction_calls=16,
|
102 |
+
**kwargs):
|
103 |
+
"""Construct a dataset for point tracking using Kubric.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
train_size: Tuple of 2 ints. Cropped output will be at this resolution
|
107 |
+
shuffle_buffer_size: Int. Size of the shuffle buffer
|
108 |
+
split: Which split to construct from Kubric. Can be 'train' or
|
109 |
+
'validation'.
|
110 |
+
batch_dims: Sequence of ints. Add multiple examples into a batch of this
|
111 |
+
shape.
|
112 |
+
repeat: Bool. whether to repeat the dataset.
|
113 |
+
vflip: Bool. whether to vertically flip the dataset to test generalization.
|
114 |
+
random_crop: Bool. whether to randomly crop videos
|
115 |
+
tracks_to_sample: Int. Total number of tracks to sample per video.
|
116 |
+
sampling_stride: Int. For efficiency, query points are sampled from a
|
117 |
+
random grid of this stride.
|
118 |
+
max_seg_id: Int. The maxium segment id in the video. Note the size of
|
119 |
+
the to graph is proportional to this number, so prefer small values.
|
120 |
+
max_sampled_frac: Float. The maximum fraction of points to sample from each
|
121 |
+
object, out of all points that lie on the sampling grid.
|
122 |
+
num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
|
123 |
+
map function for point extraction.
|
124 |
+
snap_to_occluder: If true, query points within 1 pixel of occlusion
|
125 |
+
boundaries will track the occluding surface rather than the background.
|
126 |
+
This results in models which are biased to track foreground objects
|
127 |
+
instead of background. Whether this is desirable depends on downstream
|
128 |
+
applications.
|
129 |
+
**kwargs: additional args to pass to tfds.load.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
The dataset generator.
|
133 |
+
"""
|
134 |
+
ds = tfds.load(
|
135 |
+
'panning_movi_e/256x256',
|
136 |
+
data_dir=data_dir,
|
137 |
+
shuffle_files=shuffle_buffer_size is not None,
|
138 |
+
**kwargs)
|
139 |
+
|
140 |
+
ds = ds[split]
|
141 |
+
if repeat:
|
142 |
+
ds = ds.repeat()
|
143 |
+
ds = ds.map(
|
144 |
+
functools.partial(
|
145 |
+
add_tracks,
|
146 |
+
train_size=train_size,
|
147 |
+
vflip=vflip,
|
148 |
+
random_crop=random_crop,
|
149 |
+
tracks_to_sample=tracks_to_sample,
|
150 |
+
sampling_stride=sampling_stride,
|
151 |
+
max_seg_id=max_seg_id,
|
152 |
+
max_sampled_frac=max_sampled_frac),
|
153 |
+
num_parallel_calls=num_parallel_point_extraction_calls)
|
154 |
+
if shuffle_buffer_size is not None:
|
155 |
+
ds = ds.shuffle(shuffle_buffer_size)
|
156 |
+
|
157 |
+
ds = ds.batch(batch_size)
|
158 |
+
|
159 |
+
if color_augmentation:
|
160 |
+
ds = add_default_data_augmentation(ds)
|
161 |
+
ds = tfds.as_numpy(ds)
|
162 |
+
|
163 |
+
it = iter(ds)
|
164 |
+
while True:
|
165 |
+
data = next(it)
|
166 |
+
yield data
|
167 |
+
|
168 |
+
|
169 |
+
class KubricData:
|
170 |
+
def __init__(
|
171 |
+
self,
|
172 |
+
global_rank,
|
173 |
+
data_dir,
|
174 |
+
**kwargs
|
175 |
+
):
|
176 |
+
self.global_rank = global_rank
|
177 |
+
|
178 |
+
if self.global_rank == 0:
|
179 |
+
self.data = create_point_tracking_dataset(
|
180 |
+
data_dir=data_dir,
|
181 |
+
**kwargs
|
182 |
+
)
|
183 |
+
|
184 |
+
def __getitem__(self, idx):
|
185 |
+
if self.global_rank == 0:
|
186 |
+
batch_all = next(self.data)
|
187 |
+
batch_list = []
|
188 |
+
|
189 |
+
world_size = torch.distributed.get_world_size()
|
190 |
+
batch_size = batch_all['video'].shape[0] // world_size
|
191 |
+
|
192 |
+
|
193 |
+
for i in range(world_size):
|
194 |
+
batch = {}
|
195 |
+
for k, v in batch_all.items():
|
196 |
+
if isinstance(v, (np.ndarray, torch.Tensor)):
|
197 |
+
batch[k] = torch.tensor(v[i * batch_size: (i + 1) * batch_size])
|
198 |
+
batch_list.append(batch)
|
199 |
+
else:
|
200 |
+
batch_list = [None] * torch.distributed.get_world_size()
|
201 |
+
|
202 |
+
|
203 |
+
batch = [None]
|
204 |
+
torch.distributed.scatter_object_list(batch, batch_list, src=0)
|
205 |
+
|
206 |
+
return batch[0]
|
207 |
+
|
208 |
+
|
209 |
+
if __name__ == '__main__':
|
210 |
+
|
211 |
+
import torch.nn as nn
|
212 |
+
import lightning as L
|
213 |
+
from lightning.pytorch.strategies import DDPStrategy
|
214 |
+
|
215 |
+
class Model(L.LightningModule):
|
216 |
+
def __init__(self):
|
217 |
+
super().__init__()
|
218 |
+
self.model = nn.Linear(256 * 256 * 3 * 24, 1)
|
219 |
+
|
220 |
+
def forward(self, x):
|
221 |
+
return self.model(x)
|
222 |
+
|
223 |
+
def training_step(self, batch, batch_idx):
|
224 |
+
breakpoint()
|
225 |
+
x = batch['video']
|
226 |
+
x = x.reshape(x.shape[0], -1)
|
227 |
+
y = self(x)
|
228 |
+
return y
|
229 |
+
|
230 |
+
def configure_optimizers(self):
|
231 |
+
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
232 |
+
|
233 |
+
model = Model()
|
234 |
+
|
235 |
+
trainer = L.Trainer(accelerator="cpu", strategy=DDPStrategy(), max_steps=1000, devices=1)
|
236 |
+
|
237 |
+
dataloader = KubricData(
|
238 |
+
global_rank=trainer.global_rank,
|
239 |
+
data_dir='/media/data2/PointTracking/tensorflow_datasets',
|
240 |
+
batch_size=1 * trainer.world_size,
|
241 |
+
)
|
242 |
+
|
243 |
+
trainer.fit(model, dataloader)
|
locotrack_pytorch/environment.yml
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: locotrack-pytorch
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- _libgcc_mutex=0.1=main
|
6 |
+
- _openmp_mutex=5.1=1_gnu
|
7 |
+
- bzip2=1.0.8=h5eee18b_6
|
8 |
+
- ca-certificates=2024.7.2=h06a4308_0
|
9 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
10 |
+
- libffi=3.4.4=h6a678d5_1
|
11 |
+
- libgcc-ng=11.2.0=h1234567_1
|
12 |
+
- libgomp=11.2.0=h1234567_1
|
13 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
14 |
+
- libuuid=1.41.5=h5eee18b_0
|
15 |
+
- ncurses=6.4=h6a678d5_0
|
16 |
+
- openssl=3.0.14=h5eee18b_0
|
17 |
+
- pip=24.0=py311h06a4308_0
|
18 |
+
- python=3.11.9=h955ad1f_0
|
19 |
+
- readline=8.2=h5eee18b_0
|
20 |
+
- setuptools=69.5.1=py311h06a4308_0
|
21 |
+
- sqlite=3.45.3=h5eee18b_0
|
22 |
+
- tk=8.6.14=h39e8969_0
|
23 |
+
- tzdata=2024a=h04d1e81_0
|
24 |
+
- wheel=0.43.0=py311h06a4308_0
|
25 |
+
- xz=5.4.6=h5eee18b_1
|
26 |
+
- zlib=1.2.13=h5eee18b_1
|
27 |
+
- pip:
|
28 |
+
- absl-py==2.1.0
|
29 |
+
- aiohttp==3.9.5
|
30 |
+
- aiosignal==1.3.1
|
31 |
+
- array-record==0.5.1
|
32 |
+
- asttokens==2.4.1
|
33 |
+
- astunparse==1.6.3
|
34 |
+
- attrs==23.2.0
|
35 |
+
- certifi==2024.7.4
|
36 |
+
- charset-normalizer==3.3.2
|
37 |
+
- click==8.1.7
|
38 |
+
- contourpy==1.2.1
|
39 |
+
- cycler==0.12.1
|
40 |
+
- decorator==5.1.1
|
41 |
+
- dm-tree==0.1.8
|
42 |
+
- docker-pycreds==0.4.0
|
43 |
+
- docstring-parser==0.16
|
44 |
+
- einshape==1.0
|
45 |
+
- etils==1.9.2
|
46 |
+
- executing==2.0.1
|
47 |
+
- filelock==3.13.1
|
48 |
+
- flatbuffers==24.3.25
|
49 |
+
- fonttools==4.53.1
|
50 |
+
- frozenlist==1.4.1
|
51 |
+
- fsspec==2024.2.0
|
52 |
+
- gast==0.6.0
|
53 |
+
- gitdb==4.0.11
|
54 |
+
- gitpython==3.1.43
|
55 |
+
- google-pasta==0.2.0
|
56 |
+
- googleapis-common-protos==1.63.2
|
57 |
+
- grpcio==1.65.1
|
58 |
+
- h5py==3.11.0
|
59 |
+
- idna==3.7
|
60 |
+
- immutabledict==4.2.0
|
61 |
+
- importlib-resources==6.4.0
|
62 |
+
- ipython==8.26.0
|
63 |
+
- jedi==0.19.1
|
64 |
+
- jinja2==3.1.3
|
65 |
+
- keras==3.4.1
|
66 |
+
- kiwisolver==1.4.5
|
67 |
+
- libclang==18.1.1
|
68 |
+
- lightning==2.3.3
|
69 |
+
- lightning-utilities==0.11.6
|
70 |
+
- markdown==3.6
|
71 |
+
- markdown-it-py==3.0.0
|
72 |
+
- markupsafe==2.1.5
|
73 |
+
- matplotlib==3.9.1
|
74 |
+
- matplotlib-inline==0.1.7
|
75 |
+
- mdurl==0.1.2
|
76 |
+
- mediapy==1.2.2
|
77 |
+
- ml-dtypes==0.4.0
|
78 |
+
- mpmath==1.3.0
|
79 |
+
- multidict==6.0.5
|
80 |
+
- namex==0.0.8
|
81 |
+
- networkx==3.2.1
|
82 |
+
- numpy==1.26.3
|
83 |
+
- nvidia-cublas-cu12==12.4.2.65
|
84 |
+
- nvidia-cuda-cupti-cu12==12.4.99
|
85 |
+
- nvidia-cuda-nvrtc-cu12==12.4.99
|
86 |
+
- nvidia-cuda-runtime-cu12==12.4.99
|
87 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
88 |
+
- nvidia-cufft-cu12==11.2.0.44
|
89 |
+
- nvidia-curand-cu12==10.3.5.119
|
90 |
+
- nvidia-cusolver-cu12==11.6.0.99
|
91 |
+
- nvidia-cusparse-cu12==12.3.0.142
|
92 |
+
- nvidia-nccl-cu12==2.20.5
|
93 |
+
- nvidia-nvjitlink-cu12==12.4.99
|
94 |
+
- nvidia-nvtx-cu12==12.4.99
|
95 |
+
- openexr==3.2.4
|
96 |
+
- opt-einsum==3.3.0
|
97 |
+
- optree==0.12.1
|
98 |
+
- packaging==24.1
|
99 |
+
- parso==0.8.4
|
100 |
+
- pexpect==4.9.0
|
101 |
+
- pillow==10.2.0
|
102 |
+
- platformdirs==4.2.2
|
103 |
+
- promise==2.3
|
104 |
+
- prompt-toolkit==3.0.47
|
105 |
+
- protobuf==4.25.4
|
106 |
+
- psutil==6.0.0
|
107 |
+
- ptyprocess==0.7.0
|
108 |
+
- pure-eval==0.2.3
|
109 |
+
- pyarrow==17.0.0
|
110 |
+
- pygments==2.18.0
|
111 |
+
- pyparsing==3.1.2
|
112 |
+
- python-dateutil==2.9.0.post0
|
113 |
+
- pytorch-lightning==2.3.3
|
114 |
+
- pyyaml==6.0.1
|
115 |
+
- requests==2.32.3
|
116 |
+
- rich==13.7.1
|
117 |
+
- scipy==1.14.0
|
118 |
+
- sentry-sdk==2.11.0
|
119 |
+
- setproctitle==1.3.3
|
120 |
+
- simple-parsing==0.1.5
|
121 |
+
- six==1.16.0
|
122 |
+
- smmap==5.0.1
|
123 |
+
- stack-data==0.6.3
|
124 |
+
- sympy==1.12
|
125 |
+
- tensorboard==2.17.0
|
126 |
+
- tensorboard-data-server==0.7.2
|
127 |
+
- tensorflow==2.17.0
|
128 |
+
- tensorflow-addons==0.23.0
|
129 |
+
- tensorflow-datasets==4.9.6
|
130 |
+
- tensorflow-graphics==2021.12.3
|
131 |
+
- tensorflow-io-gcs-filesystem==0.37.1
|
132 |
+
- tensorflow-metadata==1.15.0
|
133 |
+
- termcolor==2.4.0
|
134 |
+
- toml==0.10.2
|
135 |
+
- torch==2.4.0+cu124
|
136 |
+
- torchaudio==2.4.0+cu124
|
137 |
+
- torchmetrics==1.4.0.post0
|
138 |
+
- torchvision==0.19.0+cu124
|
139 |
+
- tqdm==4.66.4
|
140 |
+
- traitlets==5.14.3
|
141 |
+
- trimesh==4.4.3
|
142 |
+
- triton==3.0.0
|
143 |
+
- typeguard==2.13.3
|
144 |
+
- typing-extensions==4.9.0
|
145 |
+
- urllib3==2.2.2
|
146 |
+
- wandb==0.17.5
|
147 |
+
- wcwidth==0.2.13
|
148 |
+
- werkzeug==3.0.3
|
149 |
+
- wrapt==1.16.0
|
150 |
+
- yarl==1.9.4
|
151 |
+
- zipp==3.19.2
|
locotrack_pytorch/experiment.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import configparser
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
from functools import partial
|
6 |
+
from typing import Any, Dict, Optional, Union
|
7 |
+
|
8 |
+
import lightning as L
|
9 |
+
from lightning.pytorch import seed_everything
|
10 |
+
from lightning.pytorch.loggers import WandbLogger
|
11 |
+
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
|
15 |
+
from data.kubric_data import KubricData
|
16 |
+
from models.locotrack_model import LocoTrack
|
17 |
+
import model_utils
|
18 |
+
from data.evaluation_datasets import get_eval_dataset
|
19 |
+
|
20 |
+
|
21 |
+
class LocoTrackModel(L.LightningModule):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
model_kwargs: Optional[Dict[str, Any]] = None,
|
25 |
+
model_forward_kwargs: Optional[Dict[str, Any]] = None,
|
26 |
+
loss_name: Optional[str] = 'tapir_loss',
|
27 |
+
loss_kwargs: Optional[Dict[str, Any]] = None,
|
28 |
+
query_first: Optional[bool] = False,
|
29 |
+
optimizer_name: Optional[str] = 'Adam',
|
30 |
+
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
31 |
+
scheduler_name: Optional[str] = 'OneCycleLR',
|
32 |
+
scheduler_kwargs: Optional[Dict[str, Any]] = None,
|
33 |
+
):
|
34 |
+
super().__init__()
|
35 |
+
self.model = LocoTrack(**(model_kwargs or {}))
|
36 |
+
self.model_forward_kwargs = model_forward_kwargs or {}
|
37 |
+
self.loss = partial(model_utils.__dict__[loss_name], **(loss_kwargs or {}))
|
38 |
+
self.query_first = query_first
|
39 |
+
|
40 |
+
self.optimizer_name = optimizer_name
|
41 |
+
self.optimizer_kwargs = optimizer_kwargs or {'lr': 2e-3}
|
42 |
+
self.scheduler_name = scheduler_name
|
43 |
+
self.scheduler_kwargs = scheduler_kwargs or {'max_lr': 2e-3, 'pct_start': 0.05, 'total_steps': 300000}
|
44 |
+
|
45 |
+
def training_step(self, batch, batch_idx):
|
46 |
+
output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
|
47 |
+
loss, loss_scalars = self.loss(batch, output)
|
48 |
+
|
49 |
+
self.log_dict(
|
50 |
+
{f'train/{k}': v.item() for k, v in loss_scalars.items()},
|
51 |
+
logger=True,
|
52 |
+
on_step=True,
|
53 |
+
sync_dist=True,
|
54 |
+
)
|
55 |
+
|
56 |
+
return loss
|
57 |
+
|
58 |
+
def validation_step(self, batch, batch_idx, dataloader_idx=None):
|
59 |
+
output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
|
60 |
+
loss, loss_scalars = self.loss(batch, output)
|
61 |
+
metrics = model_utils.eval_batch(batch, output, query_first=self.query_first)
|
62 |
+
|
63 |
+
if self.trainer.global_rank == 0:
|
64 |
+
log_prefix = 'val/'
|
65 |
+
if dataloader_idx is not None:
|
66 |
+
log_prefix = f'val/data_{dataloader_idx}/'
|
67 |
+
|
68 |
+
self.log_dict(
|
69 |
+
{log_prefix + k: v for k, v in loss_scalars.items()},
|
70 |
+
logger=True,
|
71 |
+
rank_zero_only=True,
|
72 |
+
)
|
73 |
+
self.log_dict(
|
74 |
+
{log_prefix + k: v.item() for k, v in metrics.items()},
|
75 |
+
logger=True,
|
76 |
+
rank_zero_only=True,
|
77 |
+
)
|
78 |
+
logging.info(f"Batch {batch_idx}: {metrics}")
|
79 |
+
|
80 |
+
def test_step(self, batch, batch_idx, dataloader_idx=None):
|
81 |
+
output = self.model(batch['video'], batch['query_points'], **self.model_forward_kwargs)
|
82 |
+
loss, loss_scalars = self.loss(batch, output)
|
83 |
+
metrics = model_utils.eval_batch(batch, output, query_first=self.query_first)
|
84 |
+
|
85 |
+
if self.trainer.global_rank == 0:
|
86 |
+
log_prefix = 'test/'
|
87 |
+
if dataloader_idx is not None:
|
88 |
+
log_prefix = f'test/data_{dataloader_idx}/'
|
89 |
+
|
90 |
+
self.log_dict(
|
91 |
+
{log_prefix + k: v for k, v in loss_scalars.items()},
|
92 |
+
logger=True,
|
93 |
+
rank_zero_only=True,
|
94 |
+
)
|
95 |
+
self.log_dict(
|
96 |
+
{log_prefix + k: v.item() for k, v in metrics.items()},
|
97 |
+
logger=True,
|
98 |
+
rank_zero_only=True,
|
99 |
+
)
|
100 |
+
logging.info(f"Batch {batch_idx}: {metrics}")
|
101 |
+
|
102 |
+
def configure_optimizers(self):
|
103 |
+
weights = [p for n, p in self.named_parameters() if 'bias' not in n]
|
104 |
+
bias = [p for n, p in self.named_parameters() if 'bias' in n]
|
105 |
+
|
106 |
+
optimizer = torch.optim.__dict__[self.optimizer_name](
|
107 |
+
[
|
108 |
+
{'params': weights, **self.optimizer_kwargs},
|
109 |
+
{'params': bias, **self.optimizer_kwargs, 'weight_decay': 0.}
|
110 |
+
]
|
111 |
+
)
|
112 |
+
scheduler = torch.optim.lr_scheduler.__dict__[self.scheduler_name](optimizer, **self.scheduler_kwargs)
|
113 |
+
|
114 |
+
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
115 |
+
|
116 |
+
|
117 |
+
def train(
|
118 |
+
mode: str,
|
119 |
+
save_path: str,
|
120 |
+
val_dataset_path: str,
|
121 |
+
ckpt_path: str = None,
|
122 |
+
kubric_dir: str = '',
|
123 |
+
precision: str = '32',
|
124 |
+
batch_size: int = 1,
|
125 |
+
val_check_interval: Union[int, float] = 5000,
|
126 |
+
log_every_n_steps: int = 10,
|
127 |
+
gradient_clip_val: float = 1.0,
|
128 |
+
max_steps: int = 300_000,
|
129 |
+
model_kwargs: Optional[Dict[str, Any]] = None,
|
130 |
+
model_forward_kwargs: Optional[Dict[str, Any]] = None,
|
131 |
+
loss_name: str = 'tapir_loss',
|
132 |
+
loss_kwargs: Optional[Dict[str, Any]] = None,
|
133 |
+
optimizer_name: str = 'Adam',
|
134 |
+
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
135 |
+
scheduler_name: str = 'OneCycleLR',
|
136 |
+
scheduler_kwargs: Optional[Dict[str, Any]] = None,
|
137 |
+
# query_first: bool = False,
|
138 |
+
):
|
139 |
+
"""Train the LocoTrack model with specified configurations."""
|
140 |
+
seed_everything(42, workers=True)
|
141 |
+
|
142 |
+
model = LocoTrackModel(
|
143 |
+
model_kwargs=model_kwargs,
|
144 |
+
model_forward_kwargs=model_forward_kwargs,
|
145 |
+
loss_name=loss_name,
|
146 |
+
loss_kwargs=loss_kwargs,
|
147 |
+
query_first='q_first' in mode,
|
148 |
+
optimizer_name=optimizer_name,
|
149 |
+
optimizer_kwargs=optimizer_kwargs,
|
150 |
+
scheduler_name=scheduler_name,
|
151 |
+
scheduler_kwargs=scheduler_kwargs,
|
152 |
+
)
|
153 |
+
if ckpt_path is not None and 'train' in mode:
|
154 |
+
model.load_state_dict(torch.load(ckpt_path)['state_dict'])
|
155 |
+
|
156 |
+
logger = WandbLogger(project='LocoTrack_Pytorch', save_dir=save_path, id=os.path.basename(save_path))
|
157 |
+
lr_monitor = LearningRateMonitor(logging_interval='step')
|
158 |
+
checkpoint_callback = ModelCheckpoint(
|
159 |
+
dirpath=save_path,
|
160 |
+
save_last=True,
|
161 |
+
save_top_k=3,
|
162 |
+
mode="max",
|
163 |
+
monitor="val/average_pts_within_thresh",
|
164 |
+
auto_insert_metric_name=True,
|
165 |
+
save_on_train_epoch_end=False,
|
166 |
+
)
|
167 |
+
|
168 |
+
eval_dataset = get_eval_dataset(
|
169 |
+
mode=mode,
|
170 |
+
path=val_dataset_path,
|
171 |
+
)
|
172 |
+
eval_dataloder = {
|
173 |
+
k: DataLoader(
|
174 |
+
v,
|
175 |
+
batch_size=1,
|
176 |
+
shuffle=False,
|
177 |
+
) for k, v in eval_dataset.items()
|
178 |
+
}
|
179 |
+
|
180 |
+
if 'train' in mode:
|
181 |
+
trainer = L.Trainer(
|
182 |
+
strategy='ddp',
|
183 |
+
logger=logger,
|
184 |
+
precision=precision,
|
185 |
+
val_check_interval=val_check_interval,
|
186 |
+
log_every_n_steps=log_every_n_steps,
|
187 |
+
gradient_clip_val=gradient_clip_val,
|
188 |
+
max_steps=max_steps,
|
189 |
+
sync_batchnorm=True,
|
190 |
+
callbacks=[checkpoint_callback, lr_monitor],
|
191 |
+
)
|
192 |
+
train_dataloader = KubricData(
|
193 |
+
global_rank=trainer.global_rank,
|
194 |
+
data_dir=kubric_dir,
|
195 |
+
batch_size=batch_size * trainer.world_size,
|
196 |
+
)
|
197 |
+
trainer.fit(model, train_dataloader, eval_dataloder, ckpt_path=ckpt_path)
|
198 |
+
elif 'eval' in mode:
|
199 |
+
trainer = L.Trainer(strategy='ddp', logger=logger, precision=precision)
|
200 |
+
trainer.test(model, eval_dataloder, ckpt_path=ckpt_path)
|
201 |
+
else:
|
202 |
+
raise ValueError(f"Invalid mode: {mode}")
|
203 |
+
|
204 |
+
if __name__ == '__main__':
|
205 |
+
parser = argparse.ArgumentParser(description="Train or evaluate the LocoTrack model.")
|
206 |
+
parser.add_argument('--config', type=str, default='config.ini', help="Path to the configuration file.")
|
207 |
+
parser.add_argument('--mode', type=str, required=True, help="Mode to run: 'train' or 'eval' with optional 'q_first' and the name of evaluation dataset.")
|
208 |
+
parser.add_argument('--ckpt_path', type=str, default=None, help="Path to the checkpoint file")
|
209 |
+
parser.add_argument('--save_path', type=str, default='snapshots', help="Path to save the logs and checkpoints.")
|
210 |
+
|
211 |
+
args = parser.parse_args()
|
212 |
+
config = configparser.ConfigParser()
|
213 |
+
config.read(args.config)
|
214 |
+
|
215 |
+
# Extract parameters from the config file
|
216 |
+
train_params = {
|
217 |
+
'mode': args.mode,
|
218 |
+
'ckpt_path': args.ckpt_path,
|
219 |
+
'save_path': args.save_path,
|
220 |
+
'val_dataset_path': eval(config.get('TRAINING', 'val_dataset_path', fallback='{}')),
|
221 |
+
'kubric_dir': config.get('TRAINING', 'kubric_dir', fallback=''),
|
222 |
+
'precision': config.get('TRAINING', 'precision', fallback='32'),
|
223 |
+
'batch_size': config.getint('TRAINING', 'batch_size', fallback=1),
|
224 |
+
'val_check_interval': config.getfloat('TRAINING', 'val_check_interval', fallback=5000),
|
225 |
+
'log_every_n_steps': config.getint('TRAINING', 'log_every_n_steps', fallback=10),
|
226 |
+
'gradient_clip_val': config.getfloat('TRAINING', 'gradient_clip_val', fallback=1.0),
|
227 |
+
'max_steps': config.getint('TRAINING', 'max_steps', fallback=300000),
|
228 |
+
'model_kwargs': eval(config.get('MODEL', 'model_kwargs', fallback='{}')),
|
229 |
+
'model_forward_kwargs': eval(config.get('MODEL', 'model_forward_kwargs', fallback='{}')),
|
230 |
+
'loss_name': config.get('LOSS', 'loss_name', fallback='tapir_loss'),
|
231 |
+
'loss_kwargs': eval(config.get('LOSS', 'loss_kwargs', fallback='{}')),
|
232 |
+
'optimizer_name': config.get('OPTIMIZER', 'optimizer_name', fallback='Adam'),
|
233 |
+
'optimizer_kwargs': eval(config.get('OPTIMIZER', 'optimizer_kwargs', fallback='{"lr": 2e-3}')),
|
234 |
+
'scheduler_name': config.get('SCHEDULER', 'scheduler_name', fallback='OneCycleLR'),
|
235 |
+
'scheduler_kwargs': eval(config.get('SCHEDULER', 'scheduler_kwargs', fallback='{"max_lr": 2e-3, "pct_start": 0.05, "total_steps": 300000}')),
|
236 |
+
}
|
237 |
+
|
238 |
+
train(**train_params)
|
locotrack_pytorch/model_utils.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from models.utils import convert_grid_coordinates
|
7 |
+
from data.evaluation_datasets import compute_tapvid_metrics
|
8 |
+
|
9 |
+
def huber_loss(tracks, target_points, occluded, delta=4.0, reduction_axes=(1, 2)):
|
10 |
+
"""Huber loss for point trajectories."""
|
11 |
+
error = tracks - target_points
|
12 |
+
distsqr = torch.sum(error ** 2, dim=-1)
|
13 |
+
dist = torch.sqrt(distsqr + 1e-12) # add eps to prevent nan
|
14 |
+
loss_huber = torch.where(dist < delta, distsqr / 2, delta * (torch.abs(dist) - delta / 2))
|
15 |
+
loss_huber = loss_huber * (1.0 - occluded.float())
|
16 |
+
|
17 |
+
if reduction_axes:
|
18 |
+
loss_huber = torch.mean(loss_huber, dim=reduction_axes)
|
19 |
+
|
20 |
+
return loss_huber
|
21 |
+
|
22 |
+
def prob_loss(tracks, expd, target_points, occluded, expected_dist_thresh=8.0, reduction_axes=(1, 2)):
|
23 |
+
"""Loss for classifying if a point is within pixel threshold of its target."""
|
24 |
+
err = torch.sum((tracks - target_points) ** 2, dim=-1)
|
25 |
+
invalid = (err > expected_dist_thresh ** 2).float()
|
26 |
+
logprob = F.binary_cross_entropy_with_logits(expd, invalid, reduction='none')
|
27 |
+
logprob = logprob * (1.0 - occluded.float())
|
28 |
+
|
29 |
+
if reduction_axes:
|
30 |
+
logprob = torch.mean(logprob, dim=reduction_axes)
|
31 |
+
|
32 |
+
return logprob
|
33 |
+
|
34 |
+
def tapnet_loss(points, occlusion, target_points, target_occ, shape, mask=None, expected_dist=None,
|
35 |
+
position_loss_weight=0.05, expected_dist_thresh=6.0, huber_loss_delta=4.0,
|
36 |
+
rebalance_factor=None, occlusion_loss_mask=None):
|
37 |
+
"""TAPNet loss."""
|
38 |
+
|
39 |
+
if mask is None:
|
40 |
+
mask = torch.tensor(1.0)
|
41 |
+
|
42 |
+
points = convert_grid_coordinates(points, shape[3:1:-1], (256, 256), coordinate_format='xy')
|
43 |
+
target_points = convert_grid_coordinates(target_points, shape[3:1:-1], (256, 256), coordinate_format='xy')
|
44 |
+
|
45 |
+
loss_huber = huber_loss(points, target_points, target_occ, delta=huber_loss_delta, reduction_axes=None) * mask
|
46 |
+
loss_huber = torch.mean(loss_huber) * position_loss_weight
|
47 |
+
|
48 |
+
if expected_dist is None:
|
49 |
+
loss_prob = torch.tensor(0.0)
|
50 |
+
else:
|
51 |
+
loss_prob = prob_loss(points.detach(), expected_dist, target_points, target_occ, expected_dist_thresh, reduction_axes=None) * mask
|
52 |
+
loss_prob = torch.mean(loss_prob)
|
53 |
+
|
54 |
+
target_occ = target_occ.to(dtype=occlusion.dtype)
|
55 |
+
loss_occ = F.binary_cross_entropy_with_logits(occlusion, target_occ, reduction='none') * mask
|
56 |
+
|
57 |
+
if rebalance_factor is not None:
|
58 |
+
loss_occ = loss_occ * ((1 + rebalance_factor) - rebalance_factor * target_occ)
|
59 |
+
|
60 |
+
if occlusion_loss_mask is not None:
|
61 |
+
loss_occ = loss_occ * occlusion_loss_mask
|
62 |
+
|
63 |
+
loss_occ = torch.mean(loss_occ)
|
64 |
+
|
65 |
+
return loss_huber, loss_occ, loss_prob
|
66 |
+
|
67 |
+
|
68 |
+
def tapir_loss(
|
69 |
+
batch,
|
70 |
+
output,
|
71 |
+
position_loss_weight=0.05,
|
72 |
+
expected_dist_thresh=6.0,
|
73 |
+
):
|
74 |
+
loss_scalars = {}
|
75 |
+
loss_huber, loss_occ, loss_prob = tapnet_loss(
|
76 |
+
output['tracks'],
|
77 |
+
output['occlusion'],
|
78 |
+
batch['target_points'],
|
79 |
+
batch['occluded'],
|
80 |
+
batch['video'].shape, # pytype: disable=attribute-error # numpy-scalars
|
81 |
+
expected_dist=output['expected_dist']
|
82 |
+
if 'expected_dist' in output
|
83 |
+
else None,
|
84 |
+
position_loss_weight=position_loss_weight,
|
85 |
+
expected_dist_thresh=expected_dist_thresh,
|
86 |
+
)
|
87 |
+
loss = loss_huber + loss_occ + loss_prob
|
88 |
+
loss_scalars['position_loss'] = loss_huber
|
89 |
+
loss_scalars['occlusion_loss'] = loss_occ
|
90 |
+
if 'expected_dist' in output:
|
91 |
+
loss_scalars['prob_loss'] = loss_prob
|
92 |
+
|
93 |
+
if 'unrefined_tracks' in output:
|
94 |
+
for l in range(len(output['unrefined_tracks'])):
|
95 |
+
loss_huber, loss_occ, loss_prob = tapnet_loss(
|
96 |
+
output['unrefined_tracks'][l],
|
97 |
+
output['unrefined_occlusion'][l],
|
98 |
+
batch['target_points'],
|
99 |
+
batch['occluded'],
|
100 |
+
batch['video'].shape, # pytype: disable=attribute-error # numpy-scalars
|
101 |
+
expected_dist=output['unrefined_expected_dist'][l]
|
102 |
+
if 'unrefined_expected_dist' in output
|
103 |
+
else None,
|
104 |
+
position_loss_weight=position_loss_weight,
|
105 |
+
expected_dist_thresh=expected_dist_thresh,
|
106 |
+
)
|
107 |
+
loss = loss + loss_huber + loss_occ + loss_prob
|
108 |
+
loss_scalars[f'position_loss_{l}'] = loss_huber
|
109 |
+
loss_scalars[f'occlusion_loss_{l}'] = loss_occ
|
110 |
+
if 'unrefined_expected_dist' in output:
|
111 |
+
loss_scalars[f'prob_loss_{l}'] = loss_prob
|
112 |
+
|
113 |
+
loss_scalars['loss'] = loss
|
114 |
+
return loss, loss_scalars
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
def eval_batch(
|
119 |
+
batch,
|
120 |
+
output,
|
121 |
+
eval_metrics_resolution = (256, 256),
|
122 |
+
query_first = False,
|
123 |
+
):
|
124 |
+
query_points = batch['query_points']
|
125 |
+
query_points = convert_grid_coordinates(
|
126 |
+
query_points,
|
127 |
+
(1,) + batch['video'].shape[2:4], # (1, height, width)
|
128 |
+
(1,) + eval_metrics_resolution, # (1, height, width)
|
129 |
+
coordinate_format='tyx',
|
130 |
+
)
|
131 |
+
gt_target_points = batch['target_points']
|
132 |
+
gt_target_points = convert_grid_coordinates(
|
133 |
+
gt_target_points,
|
134 |
+
batch['video'].shape[3:1:-1], # (width, height)
|
135 |
+
eval_metrics_resolution[::-1], # (width, height)
|
136 |
+
coordinate_format='xy',
|
137 |
+
)
|
138 |
+
gt_occluded = batch['occluded']
|
139 |
+
|
140 |
+
tracks = output['tracks']
|
141 |
+
tracks = convert_grid_coordinates(
|
142 |
+
tracks,
|
143 |
+
batch['video'].shape[3:1:-1], # (width, height)
|
144 |
+
eval_metrics_resolution[::-1], # (width, height)
|
145 |
+
coordinate_format='xy',
|
146 |
+
)
|
147 |
+
|
148 |
+
occlusion_logits = output['occlusion']
|
149 |
+
pred_occ = torch.sigmoid(occlusion_logits)
|
150 |
+
if 'expected_dist' in output:
|
151 |
+
expected_dist = output['expected_dist']
|
152 |
+
pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(expected_dist))
|
153 |
+
pred_occ = pred_occ > 0.5 # threshold
|
154 |
+
|
155 |
+
query_mode = 'first' if query_first else 'strided'
|
156 |
+
metrics = compute_tapvid_metrics(
|
157 |
+
query_points=query_points.detach().cpu().numpy(),
|
158 |
+
gt_occluded=gt_occluded.detach().cpu().numpy(),
|
159 |
+
gt_tracks=gt_target_points.detach().cpu().numpy(),
|
160 |
+
pred_occluded=pred_occ.detach().cpu().numpy(),
|
161 |
+
pred_tracks=tracks.detach().cpu().numpy(),
|
162 |
+
query_mode=query_mode,
|
163 |
+
)
|
164 |
+
|
165 |
+
return metrics
|
locotrack_pytorch/models/cmdtop.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from models import utils
|
5 |
+
|
6 |
+
|
7 |
+
class CMDTop(nn.Module):
|
8 |
+
def __init__(self, in_channel, out_channels, kernel_shapes, strides):
|
9 |
+
super(CMDTop, self).__init__()
|
10 |
+
self.in_channels = [in_channel] + list(out_channels[:-1])
|
11 |
+
self.out_channels = out_channels
|
12 |
+
self.kernel_shapes = kernel_shapes
|
13 |
+
self.strides = strides
|
14 |
+
|
15 |
+
self.conv = nn.ModuleList([
|
16 |
+
nn.Sequential(
|
17 |
+
utils.Conv2dSamePadding(
|
18 |
+
in_channels=self.in_channels[i],
|
19 |
+
out_channels=self.out_channels[i],
|
20 |
+
kernel_size=self.kernel_shapes[i],
|
21 |
+
stride=self.strides[i],
|
22 |
+
),
|
23 |
+
nn.GroupNorm(out_channels[i] // 16, out_channels[i]),
|
24 |
+
nn.ReLU()
|
25 |
+
) for i in range(len(out_channels))
|
26 |
+
])
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
"""
|
30 |
+
x: (b, h, w, i, j)
|
31 |
+
"""
|
32 |
+
out1 = utils.einshape('bhwij->b(ij)hw', x)
|
33 |
+
out2 = utils.einshape('bhwij->b(hw)ij', x)
|
34 |
+
|
35 |
+
for i in range(len(self.out_channels)):
|
36 |
+
out1 = self.conv[i](out1)
|
37 |
+
|
38 |
+
for i in range(len(self.out_channels)):
|
39 |
+
out2 = self.conv[i](out2)
|
40 |
+
|
41 |
+
out1 = torch.mean(out1, dim=(2, 3)) # (b, out_channels[-1])
|
42 |
+
out2 = torch.mean(out2, dim=(2, 3)) # (b, out_channels[-1])
|
43 |
+
|
44 |
+
return torch.cat([out1, out2], dim=-1) # (b, 2*out_channels[-1])
|
45 |
+
|
locotrack_pytorch/models/locotrack_model.py
ADDED
@@ -0,0 +1,1053 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 DeepMind Technologies Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""TAPIR models definition."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from torch import nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
from models import nets, utils
|
27 |
+
from models.cmdtop import CMDTop
|
28 |
+
|
29 |
+
|
30 |
+
def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
|
31 |
+
"""Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
|
32 |
+
|
33 |
+
Instead of computing [sin(x), cos(x)], we use the trig identity
|
34 |
+
cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
|
38 |
+
min_deg: int, the minimum (inclusive) degree of the encoding.
|
39 |
+
max_deg: int, the maximum (exclusive) degree of the encoding.
|
40 |
+
legacy_posenc_order: bool, keep the same ordering as the original tf code.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
encoded: torch.Tensor, encoded variables.
|
44 |
+
"""
|
45 |
+
if min_deg == max_deg:
|
46 |
+
return x
|
47 |
+
scales = torch.tensor([2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device)
|
48 |
+
if legacy_posenc_order:
|
49 |
+
xb = x[..., None, :] * scales[:, None]
|
50 |
+
four_feat = torch.reshape(
|
51 |
+
torch.sin(torch.stack([xb, xb + 0.5 * np.pi], dim=-2)),
|
52 |
+
list(x.shape[:-1]) + [-1]
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
xb = torch.reshape((x[..., None, :] * scales[:, None]), list(x.shape[:-1]) + [-1])
|
56 |
+
four_feat = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1))
|
57 |
+
return torch.cat([x] + [four_feat], dim=-1)
|
58 |
+
|
59 |
+
|
60 |
+
def get_relative_positions(seq_len, reverse=False):
|
61 |
+
x = torch.arange(seq_len)[None, :]
|
62 |
+
y = torch.arange(seq_len)[:, None]
|
63 |
+
return torch.tril(x - y) if not reverse else torch.triu(y - x)
|
64 |
+
|
65 |
+
|
66 |
+
def get_alibi_slope(num_heads):
|
67 |
+
x = (24) ** (1 / num_heads)
|
68 |
+
return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], dtype=torch.float32).view(-1, 1, 1)
|
69 |
+
|
70 |
+
|
71 |
+
class MultiHeadAttention(nn.Module):
|
72 |
+
"""Multi-headed attention (MHA) module."""
|
73 |
+
|
74 |
+
def __init__(self, num_heads, key_size, w_init_scale=None, w_init=None, with_bias=True, b_init=None, value_size=None, model_size=None):
|
75 |
+
super(MultiHeadAttention, self).__init__()
|
76 |
+
self.num_heads = num_heads
|
77 |
+
self.key_size = key_size
|
78 |
+
self.value_size = value_size or key_size
|
79 |
+
self.model_size = model_size or key_size * num_heads
|
80 |
+
|
81 |
+
self.with_bias = with_bias
|
82 |
+
|
83 |
+
self.query_proj = nn.Linear(num_heads * key_size, num_heads * key_size, bias=with_bias)
|
84 |
+
self.key_proj = nn.Linear(num_heads * key_size, num_heads * key_size, bias=with_bias)
|
85 |
+
self.value_proj = nn.Linear(num_heads * self.value_size, num_heads * self.value_size, bias=with_bias)
|
86 |
+
self.final_proj = nn.Linear(num_heads * self.value_size, self.model_size, bias=with_bias)
|
87 |
+
|
88 |
+
def forward(self, query, key, value, mask=None):
|
89 |
+
batch_size, sequence_length, _ = query.size()
|
90 |
+
|
91 |
+
query_heads = self._linear_projection(query, self.key_size, self.query_proj) # [T', H, Q=K]
|
92 |
+
key_heads = self._linear_projection(key, self.key_size, self.key_proj) # [T, H, K]
|
93 |
+
value_heads = self._linear_projection(value, self.value_size, self.value_proj) # [T, H, V]
|
94 |
+
|
95 |
+
bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length)
|
96 |
+
bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
|
97 |
+
bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length, reverse=True)
|
98 |
+
bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
|
99 |
+
attn_bias = torch.cat([bias_forward, bias_backward], dim=0).to(query.device)
|
100 |
+
|
101 |
+
attn_logits = torch.einsum("...thd,...Thd->...htT", query_heads, key_heads)
|
102 |
+
attn_logits = attn_logits / np.sqrt(self.key_size) + attn_bias
|
103 |
+
|
104 |
+
if mask is not None:
|
105 |
+
if mask.ndim != attn_logits.ndim:
|
106 |
+
raise ValueError(f"Mask dimensionality {mask.ndim} must match logits dimensionality {attn_logits.ndim}.")
|
107 |
+
attn_logits = torch.where(mask, attn_logits, torch.tensor(-1e30))
|
108 |
+
|
109 |
+
attn_weights = F.softmax(attn_logits, dim=-1) # [H, T', T]
|
110 |
+
|
111 |
+
attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
|
112 |
+
attn = attn.reshape(batch_size, sequence_length, -1) # [T', H*V]
|
113 |
+
|
114 |
+
return self.final_proj(attn) # [T', D']
|
115 |
+
|
116 |
+
def _linear_projection(self, x, head_size, proj_layer):
|
117 |
+
y = proj_layer(x)
|
118 |
+
*leading_dims, _ = x.shape
|
119 |
+
return y.reshape((*leading_dims, self.num_heads, head_size))
|
120 |
+
|
121 |
+
|
122 |
+
class Transformer(nn.Module):
|
123 |
+
"""A transformer stack."""
|
124 |
+
|
125 |
+
def __init__(self, num_heads, num_layers, attn_size, dropout_rate, widening_factor=4):
|
126 |
+
super(Transformer, self).__init__()
|
127 |
+
self.num_heads = num_heads
|
128 |
+
self.num_layers = num_layers
|
129 |
+
self.attn_size = attn_size
|
130 |
+
self.dropout_rate = dropout_rate
|
131 |
+
self.widening_factor = widening_factor
|
132 |
+
|
133 |
+
self.layers = nn.ModuleList([
|
134 |
+
nn.ModuleDict({
|
135 |
+
'attn': MultiHeadAttention(num_heads, attn_size, model_size=attn_size * num_heads),
|
136 |
+
'dense': nn.Sequential(
|
137 |
+
nn.Linear(attn_size * num_heads, widening_factor * attn_size * num_heads),
|
138 |
+
nn.GELU(),
|
139 |
+
nn.Linear(widening_factor * attn_size * num_heads, attn_size * num_heads)
|
140 |
+
),
|
141 |
+
'layer_norm1': nn.LayerNorm(attn_size * num_heads),
|
142 |
+
'layer_norm2': nn.LayerNorm(attn_size * num_heads)
|
143 |
+
})
|
144 |
+
for _ in range(num_layers)
|
145 |
+
])
|
146 |
+
|
147 |
+
self.ln_out = nn.LayerNorm(attn_size * num_heads)
|
148 |
+
|
149 |
+
def forward(self, embeddings, mask=None):
|
150 |
+
h = embeddings
|
151 |
+
for layer in self.layers:
|
152 |
+
h_norm = layer['layer_norm1'](h)
|
153 |
+
h_attn = layer['attn'](h_norm, h_norm, h_norm, mask=mask)
|
154 |
+
h_attn = F.dropout(h_attn, p=self.dropout_rate, training=self.training)
|
155 |
+
h = h + h_attn
|
156 |
+
|
157 |
+
h_norm = layer['layer_norm2'](h)
|
158 |
+
h_dense = layer['dense'](h_norm)
|
159 |
+
h_dense = F.dropout(h_dense, p=self.dropout_rate, training=self.training)
|
160 |
+
h = h + h_dense
|
161 |
+
|
162 |
+
return self.ln_out(h)
|
163 |
+
|
164 |
+
|
165 |
+
class PIPSTransformer(nn.Module):
|
166 |
+
def __init__(self, input_channels, output_channels, dim=512, num_heads=8, num_layers=1):
|
167 |
+
super(PIPSTransformer, self).__init__()
|
168 |
+
self.dim = dim
|
169 |
+
|
170 |
+
self.transformer = Transformer(
|
171 |
+
num_heads=num_heads,
|
172 |
+
num_layers=num_layers,
|
173 |
+
attn_size=dim // num_heads,
|
174 |
+
dropout_rate=0.,
|
175 |
+
widening_factor=4,
|
176 |
+
)
|
177 |
+
self.input_proj = nn.Linear(input_channels, dim)
|
178 |
+
self.output_proj = nn.Linear(dim, output_channels)
|
179 |
+
|
180 |
+
def forward(self, x):
|
181 |
+
x = self.input_proj(x)
|
182 |
+
x = self.transformer(x, mask=None)
|
183 |
+
return self.output_proj(x)
|
184 |
+
|
185 |
+
|
186 |
+
class FeatureGrids(NamedTuple):
|
187 |
+
"""Feature grids for a video, used to compute trajectories.
|
188 |
+
|
189 |
+
These are per-frame outputs of the encoding resnet.
|
190 |
+
|
191 |
+
Attributes:
|
192 |
+
lowres: Low-resolution features, one for each resolution; 256 channels.
|
193 |
+
hires: High-resolution features, one for each resolution; 64 channels.
|
194 |
+
resolutions: Resolutions used for trajectory computation. There will be one
|
195 |
+
entry for the initialization, and then an entry for each PIPs refinement
|
196 |
+
resolution.
|
197 |
+
"""
|
198 |
+
|
199 |
+
lowres: Sequence[torch.Tensor]
|
200 |
+
hires: Sequence[torch.Tensor]
|
201 |
+
highest: Sequence[torch.Tensor]
|
202 |
+
resolutions: Sequence[Tuple[int, int]]
|
203 |
+
|
204 |
+
|
205 |
+
class QueryFeatures(NamedTuple):
|
206 |
+
"""Query features used to compute trajectories.
|
207 |
+
|
208 |
+
These are sampled from the query frames and are a full descriptor of the
|
209 |
+
tracked points. They can be acquired from a query image and then reused in a
|
210 |
+
separate video.
|
211 |
+
|
212 |
+
Attributes:
|
213 |
+
lowres: Low-resolution features, one for each resolution; each has shape
|
214 |
+
[batch, num_query_points, 256]
|
215 |
+
hires: High-resolution features, one for each resolution; each has shape
|
216 |
+
[batch, num_query_points, 64]
|
217 |
+
resolutions: Resolutions used for trajectory computation. There will be one
|
218 |
+
entry for the initialization, and then an entry for each PIPs refinement
|
219 |
+
resolution.
|
220 |
+
"""
|
221 |
+
|
222 |
+
lowres: Sequence[torch.Tensor]
|
223 |
+
hires: Sequence[torch.Tensor]
|
224 |
+
highest: Sequence[torch.Tensor]
|
225 |
+
lowres_supp: Sequence[torch.Tensor]
|
226 |
+
hires_supp: Sequence[torch.Tensor]
|
227 |
+
highest_supp: Sequence[torch.Tensor]
|
228 |
+
resolutions: Sequence[Tuple[int, int]]
|
229 |
+
|
230 |
+
|
231 |
+
class LocoTrack(nn.Module):
|
232 |
+
"""TAPIR model."""
|
233 |
+
|
234 |
+
def __init__(
|
235 |
+
self,
|
236 |
+
bilinear_interp_with_depthwise_conv: bool = False,
|
237 |
+
num_pips_iter: int = 4,
|
238 |
+
pyramid_level: int = 0,
|
239 |
+
mixer_hidden_dim: int = 512,
|
240 |
+
num_mixer_blocks: int = 12,
|
241 |
+
mixer_kernel_shape: int = 3,
|
242 |
+
patch_size: int = 7,
|
243 |
+
softmax_temperature: float = 20.0,
|
244 |
+
parallelize_query_extraction: bool = False,
|
245 |
+
initial_resolution: Tuple[int, int] = (256, 256),
|
246 |
+
blocks_per_group: Sequence[int] = (2, 2, 2, 2),
|
247 |
+
feature_extractor_chunk_size: int = 256,
|
248 |
+
extra_convs: bool = False,
|
249 |
+
use_casual_conv: bool = False,
|
250 |
+
model_size: str = 'base',
|
251 |
+
):
|
252 |
+
super().__init__()
|
253 |
+
|
254 |
+
if model_size == 'small':
|
255 |
+
model_params = {
|
256 |
+
'dim': 256,
|
257 |
+
'num_heads': 4,
|
258 |
+
'num_layers': 3,
|
259 |
+
}
|
260 |
+
cmdtop_params = {
|
261 |
+
'in_channel': 49,
|
262 |
+
'out_channels': (64, 128),
|
263 |
+
'kernel_shapes': (5, 2),
|
264 |
+
'strides': (4, 2),
|
265 |
+
}
|
266 |
+
elif model_size == 'base':
|
267 |
+
model_params = {
|
268 |
+
'dim': 384,
|
269 |
+
'num_heads': 6,
|
270 |
+
'num_layers': 3,
|
271 |
+
}
|
272 |
+
cmdtop_params = {
|
273 |
+
'in_channel': 49,
|
274 |
+
'out_channels': (64, 128, 128),
|
275 |
+
'kernel_shapes': (3, 3, 2),
|
276 |
+
'strides': (2, 2, 2),
|
277 |
+
}
|
278 |
+
else:
|
279 |
+
raise ValueError(f"Unknown model size '{model_size}'")
|
280 |
+
|
281 |
+
self.highres_dim = 128
|
282 |
+
self.lowres_dim = 256
|
283 |
+
self.bilinear_interp_with_depthwise_conv = (
|
284 |
+
bilinear_interp_with_depthwise_conv
|
285 |
+
)
|
286 |
+
self.parallelize_query_extraction = parallelize_query_extraction
|
287 |
+
|
288 |
+
self.num_pips_iter = num_pips_iter
|
289 |
+
self.pyramid_level = pyramid_level
|
290 |
+
self.patch_size = patch_size
|
291 |
+
self.softmax_temperature = softmax_temperature
|
292 |
+
self.initial_resolution = tuple(initial_resolution)
|
293 |
+
self.feature_extractor_chunk_size = feature_extractor_chunk_size
|
294 |
+
self.num_mixer_blocks = num_mixer_blocks
|
295 |
+
self.use_casual_conv = use_casual_conv
|
296 |
+
|
297 |
+
highres_dim = 128
|
298 |
+
lowres_dim = 256
|
299 |
+
strides = (1, 2, 2, 1)
|
300 |
+
blocks_per_group = (2, 2, 2, 2)
|
301 |
+
channels_per_group = (64, highres_dim, 256, lowres_dim)
|
302 |
+
use_projection = (True, True, True, True)
|
303 |
+
|
304 |
+
self.resnet_torch = nets.ResNet(
|
305 |
+
blocks_per_group=blocks_per_group,
|
306 |
+
channels_per_group=channels_per_group,
|
307 |
+
use_projection=use_projection,
|
308 |
+
strides=strides,
|
309 |
+
)
|
310 |
+
|
311 |
+
self.torch_pips_mixer = PIPSTransformer(
|
312 |
+
input_channels=854,
|
313 |
+
output_channels=4 + self.highres_dim + self.lowres_dim,
|
314 |
+
**model_params
|
315 |
+
)
|
316 |
+
|
317 |
+
self.cmdtop = nn.ModuleList([
|
318 |
+
CMDTop(
|
319 |
+
**cmdtop_params
|
320 |
+
) for _ in range(3)
|
321 |
+
])
|
322 |
+
|
323 |
+
self.cost_conv = utils.Conv2dSamePadding(2, 1, 3, 1)
|
324 |
+
self.occ_linear = nn.Linear(6, 2)
|
325 |
+
|
326 |
+
if extra_convs:
|
327 |
+
self.extra_convs = nets.ExtraConvs()
|
328 |
+
else:
|
329 |
+
self.extra_convs = None
|
330 |
+
|
331 |
+
def forward(
|
332 |
+
self,
|
333 |
+
video: torch.Tensor,
|
334 |
+
query_points: torch.Tensor,
|
335 |
+
feature_grids: Optional[FeatureGrids] = None,
|
336 |
+
is_training: bool = False,
|
337 |
+
query_chunk_size: Optional[int] = 64,
|
338 |
+
get_query_feats: bool = False,
|
339 |
+
refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
|
340 |
+
) -> Mapping[str, torch.Tensor]:
|
341 |
+
"""Runs a forward pass of the model.
|
342 |
+
|
343 |
+
Args:
|
344 |
+
video: A 5-D tensor representing a batch of sequences of images.
|
345 |
+
query_points: The query points for which we compute tracks.
|
346 |
+
is_training: Whether we are training.
|
347 |
+
query_chunk_size: When computing cost volumes, break the queries into
|
348 |
+
chunks of this size to save memory.
|
349 |
+
get_query_feats: Return query features for other losses like contrastive.
|
350 |
+
Not supported in the current version.
|
351 |
+
refinement_resolutions: A list of (height, width) tuples. Refinement will
|
352 |
+
be repeated at each specified resolution, in order to achieve high
|
353 |
+
accuracy on resolutions higher than what TAPIR was trained on. If None,
|
354 |
+
reasonable refinement resolutions will be inferred from the input video
|
355 |
+
size.
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
A dict of outputs, including:
|
359 |
+
occlusion: Occlusion logits, of shape [batch, num_queries, num_frames]
|
360 |
+
where higher indicates more likely to be occluded.
|
361 |
+
tracks: predicted point locations, of shape
|
362 |
+
[batch, num_queries, num_frames, 2], where each point is [x, y]
|
363 |
+
in raster coordinates
|
364 |
+
expected_dist: uncertainty estimate logits, of shape
|
365 |
+
[batch, num_queries, num_frames], where higher indicates more likely
|
366 |
+
to be far from the correct answer.
|
367 |
+
"""
|
368 |
+
if get_query_feats:
|
369 |
+
raise ValueError('Get query feats not supported in TAPIR.')
|
370 |
+
|
371 |
+
if feature_grids is None:
|
372 |
+
feature_grids = self.get_feature_grids(
|
373 |
+
video,
|
374 |
+
is_training,
|
375 |
+
refinement_resolutions,
|
376 |
+
)
|
377 |
+
|
378 |
+
query_features = self.get_query_features(
|
379 |
+
video,
|
380 |
+
is_training,
|
381 |
+
query_points,
|
382 |
+
feature_grids,
|
383 |
+
refinement_resolutions,
|
384 |
+
)
|
385 |
+
|
386 |
+
trajectories = self.estimate_trajectories(
|
387 |
+
video.shape[-3:-1],
|
388 |
+
is_training,
|
389 |
+
feature_grids,
|
390 |
+
query_features,
|
391 |
+
query_points,
|
392 |
+
query_chunk_size,
|
393 |
+
)
|
394 |
+
|
395 |
+
p = self.num_pips_iter
|
396 |
+
out = dict(
|
397 |
+
occlusion=torch.mean(
|
398 |
+
torch.stack(trajectories['occlusion'][p::p]), dim=0
|
399 |
+
),
|
400 |
+
tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
|
401 |
+
expected_dist=torch.mean(
|
402 |
+
torch.stack(trajectories['expected_dist'][p::p]), dim=0
|
403 |
+
),
|
404 |
+
unrefined_occlusion=trajectories['occlusion'][:-1],
|
405 |
+
unrefined_tracks=trajectories['tracks'][:-1],
|
406 |
+
unrefined_expected_dist=trajectories['expected_dist'][:-1],
|
407 |
+
)
|
408 |
+
|
409 |
+
return out
|
410 |
+
|
411 |
+
def get_query_features(
|
412 |
+
self,
|
413 |
+
video: torch.Tensor,
|
414 |
+
is_training: bool,
|
415 |
+
query_points: torch.Tensor,
|
416 |
+
feature_grids: Optional[FeatureGrids] = None,
|
417 |
+
refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
|
418 |
+
) -> QueryFeatures:
|
419 |
+
"""Computes query features, which can be used for estimate_trajectories.
|
420 |
+
|
421 |
+
Args:
|
422 |
+
video: A 5-D tensor representing a batch of sequences of images.
|
423 |
+
is_training: Whether we are training.
|
424 |
+
query_points: The query points for which we compute tracks.
|
425 |
+
feature_grids: If passed, we'll use these feature grids rather than
|
426 |
+
computing new ones.
|
427 |
+
refinement_resolutions: A list of (height, width) tuples. Refinement will
|
428 |
+
be repeated at each specified resolution, in order to achieve high
|
429 |
+
accuracy on resolutions higher than what TAPIR was trained on. If None,
|
430 |
+
reasonable refinement resolutions will be inferred from the input video
|
431 |
+
size.
|
432 |
+
|
433 |
+
Returns:
|
434 |
+
A QueryFeatures object which contains the required features for every
|
435 |
+
required resolution.
|
436 |
+
"""
|
437 |
+
|
438 |
+
if feature_grids is None:
|
439 |
+
feature_grids = self.get_feature_grids(
|
440 |
+
video,
|
441 |
+
is_training=is_training,
|
442 |
+
refinement_resolutions=refinement_resolutions,
|
443 |
+
)
|
444 |
+
|
445 |
+
feature_grid = feature_grids.lowres
|
446 |
+
hires_feats = feature_grids.hires
|
447 |
+
highest_feats = feature_grids.highest
|
448 |
+
resize_im_shape = feature_grids.resolutions
|
449 |
+
|
450 |
+
shape = video.shape
|
451 |
+
# shape is [batch_size, time, height, width, channels]; conversion needs
|
452 |
+
# [time, width, height]
|
453 |
+
curr_resolution = (-1, -1)
|
454 |
+
query_feats = []
|
455 |
+
hires_query_feats = []
|
456 |
+
highest_query_feats = []
|
457 |
+
query_supp = []
|
458 |
+
hires_query_supp = []
|
459 |
+
highest_query_supp = []
|
460 |
+
for i, resolution in enumerate(resize_im_shape):
|
461 |
+
if utils.is_same_res(curr_resolution, resolution):
|
462 |
+
query_feats.append(query_feats[-1])
|
463 |
+
hires_query_feats.append(hires_query_feats[-1])
|
464 |
+
highest_query_feats.append(highest_query_feats[-1])
|
465 |
+
query_supp.append(query_supp[-1])
|
466 |
+
hires_query_supp.append(hires_query_supp[-1])
|
467 |
+
highest_query_supp.append(highest_query_supp[-1])
|
468 |
+
continue
|
469 |
+
position_in_grid = utils.convert_grid_coordinates(
|
470 |
+
query_points,
|
471 |
+
shape[1:4],
|
472 |
+
feature_grid[i].shape[1:4],
|
473 |
+
coordinate_format='tyx',
|
474 |
+
)
|
475 |
+
position_in_grid_hires = utils.convert_grid_coordinates(
|
476 |
+
query_points,
|
477 |
+
shape[1:4],
|
478 |
+
hires_feats[i].shape[1:4],
|
479 |
+
coordinate_format='tyx',
|
480 |
+
)
|
481 |
+
position_in_grid_highest = utils.convert_grid_coordinates(
|
482 |
+
query_points,
|
483 |
+
shape[1:4],
|
484 |
+
highest_feats[i].shape[1:4],
|
485 |
+
coordinate_format='tyx',
|
486 |
+
)
|
487 |
+
|
488 |
+
support_size = 7
|
489 |
+
ctxx, ctxy = torch.meshgrid(
|
490 |
+
torch.arange(-(support_size // 2), support_size // 2 + 1),
|
491 |
+
torch.arange(-(support_size // 2), support_size // 2 + 1),
|
492 |
+
indexing='xy',
|
493 |
+
)
|
494 |
+
ctx = torch.stack([torch.zeros_like(ctxy), ctxy, ctxx], axis=-1)
|
495 |
+
ctx = torch.reshape(ctx, [-1, 3]).to(video.device) # s*s 3
|
496 |
+
|
497 |
+
position_support = position_in_grid[..., None, :] + ctx[None, None, ...] # b n s*s 3
|
498 |
+
position_support = utils.einshape('bnsc->b(ns)c', position_support)
|
499 |
+
interp_supp = utils.map_coordinates_3d(
|
500 |
+
feature_grid[i], position_support
|
501 |
+
)
|
502 |
+
interp_supp = utils.einshape('b(nhw)c->bnhwc', interp_supp, h=support_size, w=support_size)
|
503 |
+
|
504 |
+
position_support_hires = position_in_grid_hires[..., None, :] + ctx[None, None, ...]
|
505 |
+
position_support_hires = utils.einshape('bnsc->b(ns)c', position_support_hires)
|
506 |
+
hires_interp_supp = utils.map_coordinates_3d(
|
507 |
+
hires_feats[i], position_support_hires
|
508 |
+
)
|
509 |
+
hires_interp_supp = utils.einshape('b(nhw)c->bnhwc', hires_interp_supp, h=support_size, w=support_size)
|
510 |
+
|
511 |
+
position_support_highest = position_in_grid_highest[..., None, :] + ctx[None, None, ...]
|
512 |
+
position_support_highest = utils.einshape('bnsc->b(ns)c', position_support_highest)
|
513 |
+
highest_interp_supp = utils.map_coordinates_3d(
|
514 |
+
highest_feats[i], position_support_highest
|
515 |
+
)
|
516 |
+
highest_interp_supp = utils.einshape('b(nhw)c->bnhwc', highest_interp_supp, h=support_size, w=support_size)
|
517 |
+
|
518 |
+
interp_features = interp_supp[..., support_size // 2, support_size // 2, :]
|
519 |
+
hires_interp = hires_interp_supp[..., support_size // 2, support_size // 2, :]
|
520 |
+
highest_interp = highest_interp_supp[..., support_size // 2, support_size // 2, :]
|
521 |
+
|
522 |
+
hires_query_feats.append(hires_interp)
|
523 |
+
query_feats.append(interp_features)
|
524 |
+
highest_query_feats.append(highest_interp)
|
525 |
+
query_supp.append(interp_supp)
|
526 |
+
hires_query_supp.append(hires_interp_supp)
|
527 |
+
highest_query_supp.append(highest_interp_supp)
|
528 |
+
|
529 |
+
return QueryFeatures(
|
530 |
+
tuple(query_feats), tuple(hires_query_feats), tuple(highest_query_feats),
|
531 |
+
tuple(query_supp), tuple(hires_query_supp), tuple(highest_query_supp), tuple(resize_im_shape),
|
532 |
+
)
|
533 |
+
|
534 |
+
def get_feature_grids(
|
535 |
+
self,
|
536 |
+
video: torch.Tensor,
|
537 |
+
is_training: Optional[bool] = False,
|
538 |
+
refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
|
539 |
+
) -> FeatureGrids:
|
540 |
+
"""Computes feature grids.
|
541 |
+
|
542 |
+
Args:
|
543 |
+
video: A 5-D tensor representing a batch of sequences of images.
|
544 |
+
is_training: Whether we are training.
|
545 |
+
refinement_resolutions: A list of (height, width) tuples. Refinement will
|
546 |
+
be repeated at each specified resolution, to achieve high accuracy on
|
547 |
+
resolutions higher than what TAPIR was trained on. If None, reasonable
|
548 |
+
refinement resolutions will be inferred from the input video size.
|
549 |
+
|
550 |
+
Returns:
|
551 |
+
A FeatureGrids object containing the required features for every
|
552 |
+
required resolution. Note that there will be one more feature grid
|
553 |
+
than there are refinement_resolutions, because there is always a
|
554 |
+
feature grid computed for TAP-Net initialization.
|
555 |
+
"""
|
556 |
+
del is_training
|
557 |
+
if refinement_resolutions is None:
|
558 |
+
refinement_resolutions = utils.generate_default_resolutions(
|
559 |
+
video.shape[2:4], self.initial_resolution
|
560 |
+
)
|
561 |
+
|
562 |
+
all_required_resolutions = [self.initial_resolution]
|
563 |
+
all_required_resolutions.extend(refinement_resolutions)
|
564 |
+
|
565 |
+
feature_grid = []
|
566 |
+
hires_feats = []
|
567 |
+
highest_feats = []
|
568 |
+
resize_im_shape = []
|
569 |
+
curr_resolution = (-1, -1)
|
570 |
+
|
571 |
+
latent = None
|
572 |
+
hires = None
|
573 |
+
video_resize = None
|
574 |
+
for resolution in all_required_resolutions:
|
575 |
+
if resolution[0] % 8 != 0 or resolution[1] % 8 != 0:
|
576 |
+
raise ValueError('Image resolution must be a multiple of 8.')
|
577 |
+
|
578 |
+
if not utils.is_same_res(curr_resolution, resolution):
|
579 |
+
if utils.is_same_res(curr_resolution, video.shape[-3:-1]):
|
580 |
+
video_resize = video
|
581 |
+
else:
|
582 |
+
video_resize = utils.bilinear(video, resolution)
|
583 |
+
|
584 |
+
curr_resolution = resolution
|
585 |
+
n, f, h, w, c = video_resize.shape
|
586 |
+
video_resize = video_resize.view(n*f, h, w, c).permute(0, 3, 1, 2)
|
587 |
+
|
588 |
+
if self.feature_extractor_chunk_size > 0:
|
589 |
+
latent_list = []
|
590 |
+
hires_list = []
|
591 |
+
highest_list = []
|
592 |
+
chunk_size = self.feature_extractor_chunk_size
|
593 |
+
for start_idx in range(0, video_resize.shape[0], chunk_size):
|
594 |
+
video_chunk = video_resize[start_idx:start_idx + chunk_size]
|
595 |
+
resnet_out = self.resnet_torch(video_chunk)
|
596 |
+
|
597 |
+
u3 = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1)
|
598 |
+
latent_list.append(u3)
|
599 |
+
u1 = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1)
|
600 |
+
hires_list.append(u1)
|
601 |
+
u0 = resnet_out['resnet_unit_0'].permute(0, 2, 3, 1)
|
602 |
+
highest_list.append(u0)
|
603 |
+
|
604 |
+
latent = torch.cat(latent_list, dim=0)
|
605 |
+
hires = torch.cat(hires_list, dim=0)
|
606 |
+
highest = torch.cat(highest_list, dim=0)
|
607 |
+
|
608 |
+
else:
|
609 |
+
resnet_out = self.resnet_torch(video_resize)
|
610 |
+
latent = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1)
|
611 |
+
hires = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1)
|
612 |
+
highest = resnet_out['resnet_unit_0'].permute(0, 2, 3, 1)
|
613 |
+
|
614 |
+
if self.extra_convs:
|
615 |
+
latent = self.extra_convs(latent)
|
616 |
+
|
617 |
+
latent = latent / torch.sqrt(
|
618 |
+
torch.maximum(
|
619 |
+
torch.sum(torch.square(latent), axis=-1, keepdims=True),
|
620 |
+
torch.tensor(1e-12, device=latent.device),
|
621 |
+
)
|
622 |
+
)
|
623 |
+
hires = hires / torch.sqrt(
|
624 |
+
torch.maximum(
|
625 |
+
torch.sum(torch.square(hires), axis=-1, keepdims=True),
|
626 |
+
torch.tensor(1e-12, device=hires.device),
|
627 |
+
)
|
628 |
+
)
|
629 |
+
highest = highest / torch.sqrt(
|
630 |
+
torch.maximum(
|
631 |
+
torch.sum(torch.square(highest), axis=-1, keepdims=True),
|
632 |
+
torch.tensor(1e-12, device=highest.device),
|
633 |
+
)
|
634 |
+
)
|
635 |
+
|
636 |
+
latent = latent.view(n, f, *latent.shape[1:])
|
637 |
+
hires = hires.view(n, f, *hires.shape[1:])
|
638 |
+
highest = highest.view(n, f, *highest.shape[1:])
|
639 |
+
|
640 |
+
feature_grid.append(latent)
|
641 |
+
hires_feats.append(hires)
|
642 |
+
highest_feats.append(highest)
|
643 |
+
resize_im_shape.append(video_resize.shape[2:4])
|
644 |
+
|
645 |
+
return FeatureGrids(
|
646 |
+
tuple(feature_grid), tuple(hires_feats), tuple(highest_feats), tuple(resize_im_shape)
|
647 |
+
)
|
648 |
+
|
649 |
+
def estimate_trajectories(
|
650 |
+
self,
|
651 |
+
video_size: Tuple[int, int],
|
652 |
+
is_training: bool,
|
653 |
+
feature_grids: FeatureGrids,
|
654 |
+
query_features: QueryFeatures,
|
655 |
+
query_points_in_video: Optional[torch.Tensor],
|
656 |
+
query_chunk_size: Optional[int] = None,
|
657 |
+
causal_context: Optional[dict[str, torch.Tensor]] = None,
|
658 |
+
get_causal_context: bool = False,
|
659 |
+
) -> Mapping[str, Any]:
|
660 |
+
"""Estimates trajectories given features for a video and query features.
|
661 |
+
|
662 |
+
Args:
|
663 |
+
video_size: A 2-tuple containing the original [height, width] of the
|
664 |
+
video. Predictions will be scaled with respect to this resolution.
|
665 |
+
is_training: Whether we are training.
|
666 |
+
feature_grids: a FeatureGrids object computed for the given video.
|
667 |
+
query_features: a QueryFeatures object computed for the query points.
|
668 |
+
query_points_in_video: If provided, assume that the query points come from
|
669 |
+
the same video as feature_grids, and therefore constrain the resulting
|
670 |
+
trajectories to (approximately) pass through them.
|
671 |
+
query_chunk_size: When computing cost volumes, break the queries into
|
672 |
+
chunks of this size to save memory.
|
673 |
+
causal_context: If provided, a dict of causal context to use for
|
674 |
+
refinement.
|
675 |
+
get_causal_context: If True, return causal context in the output.
|
676 |
+
|
677 |
+
Returns:
|
678 |
+
A dict of outputs, including:
|
679 |
+
occlusion: Occlusion logits, of shape [batch, num_queries, num_frames]
|
680 |
+
where higher indicates more likely to be occluded.
|
681 |
+
tracks: predicted point locations, of shape
|
682 |
+
[batch, num_queries, num_frames, 2], where each point is [x, y]
|
683 |
+
in raster coordinates
|
684 |
+
expected_dist: uncertainty estimate logits, of shape
|
685 |
+
[batch, num_queries, num_frames], where higher indicates more likely
|
686 |
+
to be far from the correct answer.
|
687 |
+
"""
|
688 |
+
del is_training
|
689 |
+
|
690 |
+
def train2orig(x):
|
691 |
+
return utils.convert_grid_coordinates(
|
692 |
+
x,
|
693 |
+
self.initial_resolution[::-1],
|
694 |
+
video_size[::-1],
|
695 |
+
coordinate_format='xy',
|
696 |
+
)
|
697 |
+
|
698 |
+
occ_iters = []
|
699 |
+
pts_iters = []
|
700 |
+
expd_iters = []
|
701 |
+
new_causal_context = []
|
702 |
+
num_iters = self.num_pips_iter * (len(feature_grids.lowres) - 1)
|
703 |
+
for _ in range(num_iters + 1):
|
704 |
+
occ_iters.append([])
|
705 |
+
pts_iters.append([])
|
706 |
+
expd_iters.append([])
|
707 |
+
new_causal_context.append([])
|
708 |
+
del new_causal_context[-1]
|
709 |
+
|
710 |
+
infer = functools.partial(
|
711 |
+
self.tracks_from_cost_volume,
|
712 |
+
im_shp=feature_grids.lowres[0].shape[0:2]
|
713 |
+
+ self.initial_resolution
|
714 |
+
+ (3,),
|
715 |
+
)
|
716 |
+
|
717 |
+
num_queries = query_features.lowres[0].shape[1]
|
718 |
+
if causal_context is None:
|
719 |
+
perm = torch.randperm(num_queries)
|
720 |
+
else:
|
721 |
+
perm = torch.arange(num_queries)
|
722 |
+
|
723 |
+
inv_perm = torch.zeros_like(perm)
|
724 |
+
inv_perm[perm] = torch.arange(num_queries)
|
725 |
+
|
726 |
+
for ch in range(0, num_queries, query_chunk_size):
|
727 |
+
perm_chunk = perm[ch : ch + query_chunk_size]
|
728 |
+
chunk = query_features.lowres[0][:, perm_chunk]
|
729 |
+
chunk_hires = query_features.hires[0][:, perm_chunk]
|
730 |
+
|
731 |
+
cc_chunk = []
|
732 |
+
if causal_context is not None:
|
733 |
+
for d in range(len(causal_context)):
|
734 |
+
tmp_dict = {}
|
735 |
+
for k, v in causal_context[d].items():
|
736 |
+
tmp_dict[k] = v[:, perm_chunk]
|
737 |
+
cc_chunk.append(tmp_dict)
|
738 |
+
|
739 |
+
if query_points_in_video is not None:
|
740 |
+
infer_query_points = query_points_in_video[
|
741 |
+
:, perm[ch : ch + query_chunk_size]
|
742 |
+
]
|
743 |
+
num_frames = feature_grids.lowres[0].shape[1]
|
744 |
+
infer_query_points = utils.convert_grid_coordinates(
|
745 |
+
infer_query_points,
|
746 |
+
(num_frames,) + video_size,
|
747 |
+
(num_frames,) + self.initial_resolution,
|
748 |
+
coordinate_format='tyx',
|
749 |
+
)
|
750 |
+
else:
|
751 |
+
infer_query_points = None
|
752 |
+
|
753 |
+
points, occlusion, expected_dist, cost_volume = infer(
|
754 |
+
chunk,
|
755 |
+
chunk_hires,
|
756 |
+
feature_grids.lowres[0],
|
757 |
+
feature_grids.hires[0],
|
758 |
+
infer_query_points,
|
759 |
+
)
|
760 |
+
pts_iters[0].append(train2orig(points))
|
761 |
+
occ_iters[0].append(occlusion)
|
762 |
+
expd_iters[0].append(expected_dist)
|
763 |
+
|
764 |
+
mixer_feats = None
|
765 |
+
for i in range(num_iters):
|
766 |
+
feature_level = i // self.num_pips_iter + 1
|
767 |
+
queries = [
|
768 |
+
query_features.hires[feature_level][:, perm_chunk],
|
769 |
+
query_features.lowres[feature_level][:, perm_chunk],
|
770 |
+
query_features.highest[feature_level][:, perm_chunk],
|
771 |
+
]
|
772 |
+
supports = [
|
773 |
+
query_features.hires_supp[feature_level][:, perm_chunk],
|
774 |
+
query_features.lowres_supp[feature_level][:, perm_chunk],
|
775 |
+
query_features.highest_supp[feature_level][:, perm_chunk],
|
776 |
+
]
|
777 |
+
for _ in range(self.pyramid_level):
|
778 |
+
queries.append(queries[-1])
|
779 |
+
pyramid = [
|
780 |
+
feature_grids.hires[feature_level],
|
781 |
+
feature_grids.lowres[feature_level],
|
782 |
+
feature_grids.highest[feature_level],
|
783 |
+
]
|
784 |
+
for _ in range(self.pyramid_level):
|
785 |
+
pyramid.append(
|
786 |
+
F.avg_pool3d(
|
787 |
+
pyramid[-1],
|
788 |
+
kernel_size=(2, 2, 1),
|
789 |
+
stride=(2, 2, 1),
|
790 |
+
padding=0,
|
791 |
+
)
|
792 |
+
)
|
793 |
+
cc = cc_chunk[i] if causal_context is not None else None
|
794 |
+
refined = self.refine_pips(
|
795 |
+
queries,
|
796 |
+
supports,
|
797 |
+
None,
|
798 |
+
pyramid,
|
799 |
+
points.detach(),
|
800 |
+
occlusion.detach(),
|
801 |
+
expected_dist.detach(),
|
802 |
+
orig_hw=self.initial_resolution,
|
803 |
+
last_iter=mixer_feats,
|
804 |
+
mixer_iter=i,
|
805 |
+
resize_hw=feature_grids.resolutions[feature_level],
|
806 |
+
causal_context=cc,
|
807 |
+
get_causal_context=get_causal_context,
|
808 |
+
cost_volume=cost_volume
|
809 |
+
)
|
810 |
+
points, occlusion, expected_dist, mixer_feats, new_causal = refined
|
811 |
+
pts_iters[i + 1].append(train2orig(points))
|
812 |
+
occ_iters[i + 1].append(occlusion)
|
813 |
+
expd_iters[i + 1].append(expected_dist)
|
814 |
+
new_causal_context[i].append(new_causal)
|
815 |
+
|
816 |
+
if (i + 1) % self.num_pips_iter == 0:
|
817 |
+
mixer_feats = None
|
818 |
+
expected_dist = expd_iters[0][-1]
|
819 |
+
occlusion = occ_iters[0][-1]
|
820 |
+
|
821 |
+
occlusion = []
|
822 |
+
points = []
|
823 |
+
expd = []
|
824 |
+
for i, _ in enumerate(occ_iters):
|
825 |
+
occlusion.append(torch.cat(occ_iters[i], dim=1)[:, inv_perm])
|
826 |
+
points.append(torch.cat(pts_iters[i], dim=1)[:, inv_perm])
|
827 |
+
expd.append(torch.cat(expd_iters[i], dim=1)[:, inv_perm])
|
828 |
+
|
829 |
+
out = dict(
|
830 |
+
occlusion=occlusion,
|
831 |
+
tracks=points,
|
832 |
+
expected_dist=expd,
|
833 |
+
)
|
834 |
+
return out
|
835 |
+
|
836 |
+
def refine_pips(
|
837 |
+
self,
|
838 |
+
target_feature,
|
839 |
+
support_feature,
|
840 |
+
frame_features,
|
841 |
+
pyramid,
|
842 |
+
pos_guess,
|
843 |
+
occ_guess,
|
844 |
+
expd_guess,
|
845 |
+
orig_hw,
|
846 |
+
last_iter=None,
|
847 |
+
mixer_iter=0.0,
|
848 |
+
resize_hw=None,
|
849 |
+
causal_context=None,
|
850 |
+
get_causal_context=False,
|
851 |
+
cost_volume=None,
|
852 |
+
):
|
853 |
+
del frame_features
|
854 |
+
del mixer_iter
|
855 |
+
orig_h, orig_w = orig_hw
|
856 |
+
resized_h, resized_w = resize_hw
|
857 |
+
corrs_pyr = []
|
858 |
+
assert len(target_feature) == len(pyramid)
|
859 |
+
for pyridx, (query, supp, grid) in enumerate(zip(target_feature, support_feature, pyramid)):
|
860 |
+
# note: interp needs [y,x]
|
861 |
+
coords = utils.convert_grid_coordinates(
|
862 |
+
pos_guess, (orig_w, orig_h), grid.shape[-2:-4:-1]
|
863 |
+
)
|
864 |
+
coords = torch.flip(coords, dims=(-1,))
|
865 |
+
|
866 |
+
support_size = 7
|
867 |
+
ctxx, ctxy = torch.meshgrid(
|
868 |
+
torch.arange(-(support_size // 2), support_size // 2 + 1),
|
869 |
+
torch.arange(-(support_size // 2), support_size // 2 + 1),
|
870 |
+
indexing='xy',
|
871 |
+
)
|
872 |
+
ctx = torch.stack([ctxy, ctxx], dim=-1)
|
873 |
+
ctx = ctx.reshape(-1, 2).to(coords.device)
|
874 |
+
coords2 = coords.unsqueeze(3) + ctx.unsqueeze(0).unsqueeze(0).unsqueeze(0)
|
875 |
+
neighborhood = utils.map_coordinates_2d(grid, coords2)
|
876 |
+
|
877 |
+
neighborhood = utils.einshape('bnt(hw)c->bnthwc', neighborhood, h=support_size, w=support_size)
|
878 |
+
patches_input = torch.einsum('bnthwc,bnijc->bnthwij', neighborhood, supp)
|
879 |
+
patches_input = utils.einshape('bnthwij->(bnt)hwij', patches_input)
|
880 |
+
patches_emb = self.cmdtop[pyridx](patches_input)
|
881 |
+
patches = utils.einshape('(bnt)c->bntc', patches_emb, b=neighborhood.shape[0], n=neighborhood.shape[1])
|
882 |
+
|
883 |
+
corrs_pyr.append(patches)
|
884 |
+
corrs_pyr = torch.concatenate(corrs_pyr, dim=-1)
|
885 |
+
|
886 |
+
corrs_chunked = corrs_pyr
|
887 |
+
pos_guess_input = pos_guess
|
888 |
+
occ_guess_input = occ_guess[..., None]
|
889 |
+
expd_guess_input = expd_guess[..., None]
|
890 |
+
|
891 |
+
# mlp_input is batch, num_points, num_chunks, frames_per_chunk, channels
|
892 |
+
if last_iter is None:
|
893 |
+
both_feature = torch.cat([target_feature[0], target_feature[1]], axis=-1)
|
894 |
+
mlp_input_features = torch.tile(
|
895 |
+
both_feature.unsqueeze(2), (1, 1, corrs_chunked.shape[-2], 1)
|
896 |
+
)
|
897 |
+
else:
|
898 |
+
mlp_input_features = last_iter
|
899 |
+
|
900 |
+
mlp_input_list = [
|
901 |
+
occ_guess_input,
|
902 |
+
expd_guess_input,
|
903 |
+
corrs_chunked
|
904 |
+
]
|
905 |
+
|
906 |
+
rel_pos_forward = F.pad(pos_guess_input[..., :-1, :] - pos_guess_input[..., 1:, :], (0, 0, 0, 1))
|
907 |
+
rel_pos_backward = F.pad(pos_guess_input[..., 1:, :] - pos_guess_input[..., :-1, :], (0, 0, 1, 0))
|
908 |
+
scale = torch.tensor([resized_w / orig_w, resized_h / orig_h]) / torch.tensor([orig_w, orig_h])
|
909 |
+
scale = scale.to(pos_guess_input.device)
|
910 |
+
rel_pos_forward = rel_pos_forward * scale
|
911 |
+
rel_pos_backward = rel_pos_backward * scale
|
912 |
+
rel_pos_emb_input = posenc(torch.cat([rel_pos_forward, rel_pos_backward], axis=-1), min_deg=0, max_deg=10) # batch, num_points, num_frames, 84
|
913 |
+
mlp_input_list.append(rel_pos_emb_input)
|
914 |
+
mlp_input = torch.cat(mlp_input_list, axis=-1)
|
915 |
+
|
916 |
+
x = utils.einshape('bnfc->(bn)fc', mlp_input)
|
917 |
+
|
918 |
+
if causal_context is not None:
|
919 |
+
for k, v in causal_context.items():
|
920 |
+
causal_context[k] = utils.einshape('bn...->(bn)...', v)
|
921 |
+
res = self.torch_pips_mixer(x)
|
922 |
+
|
923 |
+
res = utils.einshape('(bn)fc->bnfc', res, b=mlp_input.shape[0])
|
924 |
+
|
925 |
+
pos_update = utils.convert_grid_coordinates(
|
926 |
+
res[..., :2],
|
927 |
+
(resized_w, resized_h),
|
928 |
+
(orig_w, orig_h),
|
929 |
+
)
|
930 |
+
return (
|
931 |
+
pos_update + pos_guess,
|
932 |
+
res[..., 2] + occ_guess,
|
933 |
+
res[..., 3] + expd_guess,
|
934 |
+
res[..., 4:] + (mlp_input_features if last_iter is None else last_iter),
|
935 |
+
None,
|
936 |
+
)
|
937 |
+
|
938 |
+
def tracks_from_cost_volume(
|
939 |
+
self,
|
940 |
+
interp_feature: torch.Tensor,
|
941 |
+
interp_feature_hires: torch.Tensor,
|
942 |
+
feature_grid: torch.Tensor,
|
943 |
+
feature_grid_hires: torch.Tensor,
|
944 |
+
query_points: Optional[torch.Tensor],
|
945 |
+
im_shp=None,
|
946 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
947 |
+
"""Converts features into tracks by computing a cost volume.
|
948 |
+
|
949 |
+
The computed cost volume will have shape
|
950 |
+
[batch, num_queries, time, height, width], which can be very
|
951 |
+
memory intensive.
|
952 |
+
|
953 |
+
Args:
|
954 |
+
interp_feature: A tensor of features for each query point, of shape
|
955 |
+
[batch, num_queries, channels, heads].
|
956 |
+
feature_grid: A tensor of features for the video, of shape [batch, time,
|
957 |
+
height, width, channels, heads].
|
958 |
+
query_points: When computing tracks, we assume these points are given as
|
959 |
+
ground truth and we reproduce them exactly. This is a set of points of
|
960 |
+
shape [batch, num_points, 3], where each entry is [t, y, x] in frame/
|
961 |
+
raster coordinates.
|
962 |
+
im_shp: The shape of the original image, i.e., [batch, num_frames, time,
|
963 |
+
height, width, 3].
|
964 |
+
|
965 |
+
Returns:
|
966 |
+
A 2-tuple of the inferred points (of shape
|
967 |
+
[batch, num_points, num_frames, 2] where each point is [x, y]) and
|
968 |
+
inferred occlusion (of shape [batch, num_points, num_frames], where
|
969 |
+
each is a logit where higher means occluded)
|
970 |
+
"""
|
971 |
+
|
972 |
+
cost_volume = torch.einsum(
|
973 |
+
'bnc,bthwc->tbnhw',
|
974 |
+
interp_feature,
|
975 |
+
feature_grid,
|
976 |
+
)
|
977 |
+
cost_volume_hires = torch.einsum(
|
978 |
+
'bnc,bthwc->tbnhw',
|
979 |
+
interp_feature_hires,
|
980 |
+
feature_grid_hires,
|
981 |
+
)
|
982 |
+
|
983 |
+
shape = cost_volume.shape
|
984 |
+
batch_size, num_points = cost_volume.shape[1:3]
|
985 |
+
|
986 |
+
interp_cost = utils.einshape('tbnhw->(tbn)1hw', cost_volume)
|
987 |
+
interp_cost = F.interpolate(interp_cost, cost_volume_hires.shape[3:], mode='bilinear', align_corners=False)
|
988 |
+
# TODO: not sure if this is correct
|
989 |
+
interp_cost = utils.einshape('(tbn)1hw->tbnhw', interp_cost, b=batch_size, n=num_points)
|
990 |
+
cost_volume_stack = torch.stack(
|
991 |
+
[
|
992 |
+
# jax.image.resize(cost_volume, cost_volume_hires.shape, method='bilinear'),
|
993 |
+
interp_cost,
|
994 |
+
cost_volume_hires,
|
995 |
+
], dim=-1
|
996 |
+
)
|
997 |
+
pos = utils.einshape('tbnhwc->(tbn)chw', cost_volume_stack)
|
998 |
+
pos = self.cost_conv(pos)
|
999 |
+
pos = utils.einshape('(tbn)1hw->bnthw', pos, b=batch_size, n=num_points)
|
1000 |
+
|
1001 |
+
pos_sm = pos.reshape(pos.size(0), pos.size(1), pos.size(2), -1)
|
1002 |
+
softmaxed = F.softmax(pos_sm * self.softmax_temperature, dim=-1)
|
1003 |
+
pos = softmaxed.view_as(pos)
|
1004 |
+
|
1005 |
+
points = utils.heatmaps_to_points(pos, im_shp, query_points=query_points)
|
1006 |
+
|
1007 |
+
occlusion = torch.cat(
|
1008 |
+
[
|
1009 |
+
torch.mean(cost_volume_stack, dim=(-2, -3)),
|
1010 |
+
torch.amax(cost_volume_stack, dim=(-2, -3)),
|
1011 |
+
torch.amin(cost_volume_stack, dim=(-2, -3)),
|
1012 |
+
], dim=-1
|
1013 |
+
)
|
1014 |
+
occlusion = self.occ_linear(occlusion)
|
1015 |
+
expected_dist = utils.einshape(
|
1016 |
+
'tbn1->bnt', occlusion[..., 1:2]
|
1017 |
+
)
|
1018 |
+
occlusion = utils.einshape(
|
1019 |
+
'tbn1->bnt', occlusion[..., 0:1]
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
return points, occlusion, expected_dist, utils.einshape('tbnhw->bnthw', cost_volume)
|
1023 |
+
|
1024 |
+
def construct_initial_causal_state(self, num_points, num_resolutions=1):
|
1025 |
+
"""Construct initial causal state."""
|
1026 |
+
value_shapes = {}
|
1027 |
+
for i in range(self.num_mixer_blocks):
|
1028 |
+
value_shapes[f'block_{i}_causal_1'] = (1, num_points, 2, 512)
|
1029 |
+
value_shapes[f'block_{i}_causal_2'] = (1, num_points, 2, 2048)
|
1030 |
+
fake_ret = {
|
1031 |
+
k: torch.zeros(v, dtype=torch.float32) for k, v in value_shapes.items()
|
1032 |
+
}
|
1033 |
+
return [fake_ret] * num_resolutions * 4
|
1034 |
+
|
1035 |
+
|
1036 |
+
CHECKPOINT_LINK = {
|
1037 |
+
'small': 'https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_small.ckpt',
|
1038 |
+
'base': 'https://huggingface.co/datasets/hamacojr/LocoTrack-pytorch-weights/resolve/main/locotrack_base.ckpt',
|
1039 |
+
}
|
1040 |
+
|
1041 |
+
def load_model(ckpt_path=None, model_size='base'):
|
1042 |
+
if ckpt_path is None:
|
1043 |
+
ckpt_link = CHECKPOINT_LINK[model_size]
|
1044 |
+
state_dict = torch.hub.load_state_dict_from_url(ckpt_link, map_location='cpu')['state_dict']
|
1045 |
+
else:
|
1046 |
+
state_dict = torch.load(ckpt_path)['state_dict']
|
1047 |
+
state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
|
1048 |
+
|
1049 |
+
model = LocoTrack(model_size=model_size)
|
1050 |
+
model.load_state_dict(state_dict)
|
1051 |
+
model.eval()
|
1052 |
+
|
1053 |
+
return model
|
locotrack_pytorch/models/nets.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 DeepMind Technologies Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Pytorch neural network definitions."""
|
17 |
+
|
18 |
+
from typing import Sequence, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch import nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from models.utils import Conv2dSamePadding
|
25 |
+
|
26 |
+
|
27 |
+
class ExtraConvBlock(nn.Module):
|
28 |
+
"""Additional convolution block."""
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
channel_dim,
|
33 |
+
channel_multiplier,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
self.channel_dim = channel_dim
|
37 |
+
self.channel_multiplier = channel_multiplier
|
38 |
+
|
39 |
+
self.layer_norm = nn.LayerNorm(
|
40 |
+
normalized_shape=channel_dim, elementwise_affine=True, bias=True
|
41 |
+
)
|
42 |
+
self.conv = nn.Conv2d(
|
43 |
+
self.channel_dim,
|
44 |
+
self.channel_dim * self.channel_multiplier,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1,
|
48 |
+
)
|
49 |
+
self.conv_1 = nn.Conv2d(
|
50 |
+
self.channel_dim * self.channel_multiplier,
|
51 |
+
self.channel_dim,
|
52 |
+
kernel_size=3,
|
53 |
+
stride=1,
|
54 |
+
padding=1,
|
55 |
+
)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x = self.layer_norm(x)
|
59 |
+
x = x.permute(0, 3, 1, 2)
|
60 |
+
res = self.conv(x)
|
61 |
+
res = F.gelu(res, approximate='tanh')
|
62 |
+
x = x + self.conv_1(res)
|
63 |
+
x = x.permute(0, 2, 3, 1)
|
64 |
+
return x
|
65 |
+
|
66 |
+
|
67 |
+
class ExtraConvs(nn.Module):
|
68 |
+
"""Additional CNN."""
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
num_layers=5,
|
73 |
+
channel_dim=256,
|
74 |
+
channel_multiplier=4,
|
75 |
+
):
|
76 |
+
super().__init__()
|
77 |
+
self.num_layers = num_layers
|
78 |
+
self.channel_dim = channel_dim
|
79 |
+
self.channel_multiplier = channel_multiplier
|
80 |
+
|
81 |
+
self.blocks = nn.ModuleList()
|
82 |
+
for _ in range(self.num_layers):
|
83 |
+
self.blocks.append(
|
84 |
+
ExtraConvBlock(self.channel_dim, self.channel_multiplier)
|
85 |
+
)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
for block in self.blocks:
|
89 |
+
x = block(x)
|
90 |
+
|
91 |
+
return x
|
92 |
+
|
93 |
+
|
94 |
+
class ConvChannelsMixer(nn.Module):
|
95 |
+
"""Linear activation block for PIPs's MLP Mixer."""
|
96 |
+
|
97 |
+
def __init__(self, in_channels):
|
98 |
+
super().__init__()
|
99 |
+
self.mlp2_up = nn.Linear(in_channels, in_channels * 4)
|
100 |
+
self.mlp2_down = nn.Linear(in_channels * 4, in_channels)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
x = self.mlp2_up(x)
|
104 |
+
x = F.gelu(x, approximate='tanh')
|
105 |
+
x = self.mlp2_down(x)
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class PIPsConvBlock(nn.Module):
|
110 |
+
"""Convolutional block for PIPs's MLP Mixer."""
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self, in_channels, kernel_shape=3, use_causal_conv=False, block_idx=None
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
self.use_causal_conv = use_causal_conv
|
117 |
+
self.block_name = f'block_{block_idx}'
|
118 |
+
self.kernel_shape = kernel_shape
|
119 |
+
|
120 |
+
self.layer_norm = nn.LayerNorm(
|
121 |
+
normalized_shape=in_channels, elementwise_affine=True, bias=False
|
122 |
+
)
|
123 |
+
self.mlp1_up = nn.Conv1d(
|
124 |
+
in_channels,
|
125 |
+
in_channels * 4,
|
126 |
+
kernel_shape,
|
127 |
+
stride=1,
|
128 |
+
padding=0 if self.use_causal_conv else 1,
|
129 |
+
groups=in_channels,
|
130 |
+
)
|
131 |
+
|
132 |
+
self.mlp1_up_1 = nn.Conv1d(
|
133 |
+
in_channels * 4,
|
134 |
+
in_channels * 4,
|
135 |
+
kernel_shape,
|
136 |
+
stride=1,
|
137 |
+
padding=0 if self.use_causal_conv else 1,
|
138 |
+
groups=in_channels * 4,
|
139 |
+
)
|
140 |
+
self.layer_norm_1 = nn.LayerNorm(
|
141 |
+
normalized_shape=in_channels, elementwise_affine=True, bias=False
|
142 |
+
)
|
143 |
+
self.conv_channels_mixer = ConvChannelsMixer(in_channels)
|
144 |
+
|
145 |
+
def forward(self, x, causal_context=None, get_causal_context=False):
|
146 |
+
to_skip = x
|
147 |
+
x = self.layer_norm(x)
|
148 |
+
new_causal_context = {}
|
149 |
+
num_extra = 0
|
150 |
+
|
151 |
+
if causal_context is not None:
|
152 |
+
name1 = self.block_name + '_causal_1'
|
153 |
+
x = torch.cat([causal_context[name1], x], dim=-2)
|
154 |
+
num_extra = causal_context[name1].shape[-2]
|
155 |
+
new_causal_context[name1] = x[..., -(self.kernel_shape - 1) :, :]
|
156 |
+
|
157 |
+
x = x.permute(0, 2, 1)
|
158 |
+
if self.use_causal_conv:
|
159 |
+
x = F.pad(x, (2, 0))
|
160 |
+
x = self.mlp1_up(x)
|
161 |
+
|
162 |
+
x = F.gelu(x, approximate='tanh')
|
163 |
+
|
164 |
+
if causal_context is not None:
|
165 |
+
x = x.permute(0, 2, 1)
|
166 |
+
name2 = self.block_name + '_causal_2'
|
167 |
+
num_extra = causal_context[name2].shape[-2]
|
168 |
+
x = torch.cat([causal_context[name2], x[..., num_extra:, :]], dim=-2)
|
169 |
+
new_causal_context[name2] = x[..., -(self.kernel_shape - 1) :, :]
|
170 |
+
x = x.permute(0, 2, 1)
|
171 |
+
|
172 |
+
if self.use_causal_conv:
|
173 |
+
x = F.pad(x, (2, 0))
|
174 |
+
x = self.mlp1_up_1(x)
|
175 |
+
x = x.permute(0, 2, 1)
|
176 |
+
|
177 |
+
if causal_context is not None:
|
178 |
+
x = x[..., num_extra:, :]
|
179 |
+
|
180 |
+
x = x[..., 0::4] + x[..., 1::4] + x[..., 2::4] + x[..., 3::4]
|
181 |
+
|
182 |
+
x = x + to_skip
|
183 |
+
to_skip = x
|
184 |
+
x = self.layer_norm_1(x)
|
185 |
+
x = self.conv_channels_mixer(x)
|
186 |
+
|
187 |
+
x = x + to_skip
|
188 |
+
return x, new_causal_context
|
189 |
+
|
190 |
+
|
191 |
+
class PIPSMLPMixer(nn.Module):
|
192 |
+
"""Depthwise-conv version of PIPs's MLP Mixer."""
|
193 |
+
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
input_channels: int,
|
197 |
+
output_channels: int,
|
198 |
+
hidden_dim: int = 512,
|
199 |
+
num_blocks: int = 12,
|
200 |
+
kernel_shape: int = 3,
|
201 |
+
use_causal_conv: bool = False,
|
202 |
+
):
|
203 |
+
"""Inits Mixer module.
|
204 |
+
|
205 |
+
A depthwise-convolutional version of a MLP Mixer for processing images.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
input_channels (int): The number of input channels.
|
209 |
+
output_channels (int): The number of output channels.
|
210 |
+
hidden_dim (int, optional): The dimension of the hidden layer. Defaults
|
211 |
+
to 512.
|
212 |
+
num_blocks (int, optional): The number of convolution blocks in the
|
213 |
+
mixer. Defaults to 12.
|
214 |
+
kernel_shape (int, optional): The size of the kernel in the convolution
|
215 |
+
blocks. Defaults to 3.
|
216 |
+
use_causal_conv (bool, optional): Whether to use causal convolutions.
|
217 |
+
Defaults to False.
|
218 |
+
"""
|
219 |
+
|
220 |
+
super().__init__()
|
221 |
+
self.hidden_dim = hidden_dim
|
222 |
+
self.num_blocks = num_blocks
|
223 |
+
self.use_causal_conv = use_causal_conv
|
224 |
+
self.linear = nn.Linear(input_channels, self.hidden_dim)
|
225 |
+
self.layer_norm = nn.LayerNorm(
|
226 |
+
normalized_shape=hidden_dim, elementwise_affine=True, bias=False
|
227 |
+
)
|
228 |
+
self.linear_1 = nn.Linear(hidden_dim, output_channels)
|
229 |
+
self.blocks = nn.ModuleList([
|
230 |
+
PIPsConvBlock(
|
231 |
+
hidden_dim, kernel_shape, self.use_causal_conv, block_idx=i
|
232 |
+
)
|
233 |
+
for i in range(num_blocks)
|
234 |
+
])
|
235 |
+
|
236 |
+
def forward(self, x, causal_context=None, get_causal_context=False):
|
237 |
+
x = self.linear(x)
|
238 |
+
all_causal_context = {}
|
239 |
+
for block in self.blocks:
|
240 |
+
x, new_causal_context = block(x, causal_context, get_causal_context)
|
241 |
+
if get_causal_context:
|
242 |
+
all_causal_context.update(new_causal_context)
|
243 |
+
|
244 |
+
x = self.layer_norm(x)
|
245 |
+
x = self.linear_1(x)
|
246 |
+
return x, all_causal_context
|
247 |
+
|
248 |
+
|
249 |
+
class BlockV2(nn.Module):
|
250 |
+
"""ResNet V2 block."""
|
251 |
+
|
252 |
+
def __init__(
|
253 |
+
self,
|
254 |
+
channels_in: int,
|
255 |
+
channels_out: int,
|
256 |
+
stride: Union[int, Sequence[int]],
|
257 |
+
use_projection: bool,
|
258 |
+
):
|
259 |
+
super().__init__()
|
260 |
+
self.padding = (1, 1, 1, 1)
|
261 |
+
# Handle assymetric padding created by padding="SAME" in JAX/LAX
|
262 |
+
if stride == 1:
|
263 |
+
self.padding = (1, 1, 1, 1)
|
264 |
+
elif stride == 2:
|
265 |
+
self.padding = (0, 2, 0, 2)
|
266 |
+
else:
|
267 |
+
raise ValueError(
|
268 |
+
'Check correct padding using padtype_to_padsin jax._src.lax.lax'
|
269 |
+
)
|
270 |
+
|
271 |
+
self.use_projection = use_projection
|
272 |
+
if self.use_projection:
|
273 |
+
self.proj_conv = Conv2dSamePadding(
|
274 |
+
in_channels=channels_in,
|
275 |
+
out_channels=channels_out,
|
276 |
+
kernel_size=1,
|
277 |
+
stride=stride,
|
278 |
+
padding=0,
|
279 |
+
bias=False,
|
280 |
+
)
|
281 |
+
|
282 |
+
self.bn_0 = nn.InstanceNorm2d(
|
283 |
+
num_features=channels_in,
|
284 |
+
eps=1e-05,
|
285 |
+
momentum=0.1,
|
286 |
+
affine=True,
|
287 |
+
track_running_stats=False,
|
288 |
+
)
|
289 |
+
self.conv_0 = Conv2dSamePadding(
|
290 |
+
in_channels=channels_in,
|
291 |
+
out_channels=channels_out,
|
292 |
+
kernel_size=3,
|
293 |
+
stride=stride,
|
294 |
+
padding=0,
|
295 |
+
bias=False,
|
296 |
+
)
|
297 |
+
|
298 |
+
self.conv_1 = Conv2dSamePadding(
|
299 |
+
in_channels=channels_out,
|
300 |
+
out_channels=channels_out,
|
301 |
+
kernel_size=3,
|
302 |
+
stride=1,
|
303 |
+
padding=1,
|
304 |
+
bias=False,
|
305 |
+
)
|
306 |
+
self.bn_1 = nn.InstanceNorm2d(
|
307 |
+
num_features=channels_out,
|
308 |
+
eps=1e-05,
|
309 |
+
momentum=0.1,
|
310 |
+
affine=True,
|
311 |
+
track_running_stats=False,
|
312 |
+
)
|
313 |
+
|
314 |
+
def forward(self, inputs):
|
315 |
+
x = shortcut = inputs
|
316 |
+
|
317 |
+
x = self.bn_0(x)
|
318 |
+
x = torch.relu(x)
|
319 |
+
if self.use_projection:
|
320 |
+
shortcut = self.proj_conv(x)
|
321 |
+
|
322 |
+
x = self.conv_0(x)
|
323 |
+
|
324 |
+
x = self.bn_1(x)
|
325 |
+
x = torch.relu(x)
|
326 |
+
# no issues with padding here as this layer always has stride 1
|
327 |
+
x = self.conv_1(x)
|
328 |
+
|
329 |
+
return x + shortcut
|
330 |
+
|
331 |
+
|
332 |
+
class BlockGroup(nn.Module):
|
333 |
+
"""Higher level block for ResNet implementation."""
|
334 |
+
|
335 |
+
def __init__(
|
336 |
+
self,
|
337 |
+
channels_in: int,
|
338 |
+
channels_out: int,
|
339 |
+
num_blocks: int,
|
340 |
+
stride: Union[int, Sequence[int]],
|
341 |
+
use_projection: bool,
|
342 |
+
):
|
343 |
+
super().__init__()
|
344 |
+
blocks = []
|
345 |
+
for i in range(num_blocks):
|
346 |
+
blocks.append(
|
347 |
+
BlockV2(
|
348 |
+
channels_in=channels_in if i == 0 else channels_out,
|
349 |
+
channels_out=channels_out,
|
350 |
+
stride=(1 if i else stride),
|
351 |
+
use_projection=(i == 0 and use_projection),
|
352 |
+
)
|
353 |
+
)
|
354 |
+
self.blocks = nn.ModuleList(blocks)
|
355 |
+
|
356 |
+
def forward(self, inputs):
|
357 |
+
out = inputs
|
358 |
+
for block in self.blocks:
|
359 |
+
out = block(out)
|
360 |
+
return out
|
361 |
+
|
362 |
+
|
363 |
+
class ResNet(nn.Module):
|
364 |
+
"""ResNet model."""
|
365 |
+
|
366 |
+
def __init__(
|
367 |
+
self,
|
368 |
+
blocks_per_group: Sequence[int],
|
369 |
+
channels_per_group: Sequence[int] = (64, 128, 256, 512),
|
370 |
+
use_projection: Sequence[bool] = (True, True, True, True),
|
371 |
+
strides: Sequence[int] = (1, 2, 2, 2),
|
372 |
+
):
|
373 |
+
"""Initializes a ResNet model with customizable layers and configurations.
|
374 |
+
|
375 |
+
This constructor allows defining the architecture of a ResNet model by
|
376 |
+
setting the number of blocks, channels, projection usage, and strides for
|
377 |
+
each group of blocks within the network. It provides flexibility in
|
378 |
+
creating various ResNet configurations.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
blocks_per_group: A sequence of 4 integers, each indicating the number
|
382 |
+
of residual blocks in each group.
|
383 |
+
channels_per_group: A sequence of 4 integers, each specifying the number
|
384 |
+
of output channels for the blocks in each group. Defaults to (64, 128,
|
385 |
+
256, 512).
|
386 |
+
use_projection: A sequence of 4 booleans, each indicating whether to use
|
387 |
+
a projection shortcut (True) or an identity shortcut (False) in each
|
388 |
+
group. Defaults to (True, True, True, True).
|
389 |
+
strides: A sequence of 4 integers, each specifying the stride size for
|
390 |
+
the convolutions in each group. Defaults to (1, 2, 2, 2).
|
391 |
+
|
392 |
+
The ResNet model created will have 4 groups, with each group's
|
393 |
+
architecture defined by the corresponding elements in these sequences.
|
394 |
+
"""
|
395 |
+
super().__init__()
|
396 |
+
|
397 |
+
self.initial_conv = Conv2dSamePadding(
|
398 |
+
in_channels=3,
|
399 |
+
out_channels=channels_per_group[0],
|
400 |
+
kernel_size=(7, 7),
|
401 |
+
stride=2,
|
402 |
+
padding=0,
|
403 |
+
bias=False,
|
404 |
+
)
|
405 |
+
|
406 |
+
block_groups = []
|
407 |
+
for i, _ in enumerate(strides):
|
408 |
+
block_groups.append(
|
409 |
+
BlockGroup(
|
410 |
+
channels_in=channels_per_group[i - 1] if i > 0 else 64,
|
411 |
+
channels_out=channels_per_group[i],
|
412 |
+
num_blocks=blocks_per_group[i],
|
413 |
+
stride=strides[i],
|
414 |
+
use_projection=use_projection[i],
|
415 |
+
)
|
416 |
+
)
|
417 |
+
self.block_groups = nn.ModuleList(block_groups)
|
418 |
+
|
419 |
+
def forward(self, inputs):
|
420 |
+
result = {}
|
421 |
+
out = inputs
|
422 |
+
out = self.initial_conv(out)
|
423 |
+
result['initial_conv'] = out
|
424 |
+
|
425 |
+
for block_id, block_group in enumerate(self.block_groups):
|
426 |
+
out = block_group(out)
|
427 |
+
result[f'resnet_unit_{block_id}'] = out
|
428 |
+
|
429 |
+
return result
|
locotrack_pytorch/models/utils.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 DeepMind Technologies Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Pytorch model utilities."""
|
17 |
+
import math
|
18 |
+
from typing import Any, Sequence, Union
|
19 |
+
from einshape.src import abstract_ops
|
20 |
+
from einshape.src import backend
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
|
26 |
+
def bilinear(x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
|
27 |
+
"""Resizes a 5D tensor using bilinear interpolation.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
x: A 5D tensor of shape (B, T, W, H, C) where B is batch size, T is
|
31 |
+
time, W is width, H is height, and C is the number of channels.
|
32 |
+
resolution: The target resolution as a tuple (new_width, new_height).
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
The resized tensor.
|
36 |
+
"""
|
37 |
+
b, t, h, w, c = x.size()
|
38 |
+
x = x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
|
39 |
+
x = F.interpolate(x, size=resolution, mode='bilinear', align_corners=False)
|
40 |
+
b, _, h, w = x.size()
|
41 |
+
x = x.reshape(b, t, c, h, w).permute(0, 1, 3, 4, 2)
|
42 |
+
return x
|
43 |
+
|
44 |
+
|
45 |
+
def map_coordinates_3d(
|
46 |
+
feats: torch.Tensor, coordinates: torch.Tensor
|
47 |
+
) -> torch.Tensor:
|
48 |
+
"""Maps 3D coordinates to corresponding features using bilinear interpolation.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
feats: A 5D tensor of features with shape (B, W, H, D, C), where B is batch
|
52 |
+
size, W is width, H is height, D is depth, and C is the number of
|
53 |
+
channels.
|
54 |
+
coordinates: A 3D tensor of coordinates with shape (B, N, 3), where N is the
|
55 |
+
number of coordinates and the last dimension represents (W, H, D)
|
56 |
+
coordinates.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
The mapped features tensor.
|
60 |
+
"""
|
61 |
+
x = feats.permute(0, 4, 1, 2, 3)
|
62 |
+
y = coordinates[:, :, None, None, :].float().clone()
|
63 |
+
y[..., 0] = y[..., 0] + 0.5
|
64 |
+
y = 2 * (y / torch.tensor(x.shape[2:], device=y.device)) - 1
|
65 |
+
y = torch.flip(y, dims=(-1,))
|
66 |
+
out = (
|
67 |
+
F.grid_sample(
|
68 |
+
x, y, mode='bilinear', align_corners=False, padding_mode='border'
|
69 |
+
)
|
70 |
+
.squeeze(dim=(3, 4))
|
71 |
+
.permute(0, 2, 1)
|
72 |
+
)
|
73 |
+
return out
|
74 |
+
|
75 |
+
|
76 |
+
def map_coordinates_2d(
|
77 |
+
feats: torch.Tensor, coordinates: torch.Tensor
|
78 |
+
) -> torch.Tensor:
|
79 |
+
"""Maps 2D coordinates to feature maps using bilinear interpolation.
|
80 |
+
|
81 |
+
The function performs bilinear interpolation on the feature maps (`feats`)
|
82 |
+
at the specified `coordinates`. The coordinates are normalized between
|
83 |
+
-1 and 1 The result is a tensor of sampled features corresponding
|
84 |
+
to these coordinates.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
feats (Tensor): A 5D tensor of shape (N, T, H, W, C) representing feature
|
88 |
+
maps, where N is the batch size, T is the number of frames, H and W are
|
89 |
+
height and width, and C is the number of channels.
|
90 |
+
coordinates (Tensor): A 5D tensor of shape (N, P, T, S, XY) representing
|
91 |
+
coordinates, where N is the batch size, P is the number of points, T is
|
92 |
+
the number of frames, S is the number of samples, and XY represents the 2D
|
93 |
+
coordinates.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
Tensor: A 5D tensor of the sampled features corresponding to the
|
97 |
+
given coordinates, of shape (N, P, T, S, C).
|
98 |
+
"""
|
99 |
+
n, t, h, w, c = feats.shape
|
100 |
+
x = feats.permute(0, 1, 4, 2, 3).view(n * t, c, h, w)
|
101 |
+
|
102 |
+
n, p, t, s, xy = coordinates.shape
|
103 |
+
y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy)
|
104 |
+
y = 2 * (y / h) - 1
|
105 |
+
y = torch.flip(y, dims=(-1,)).float()
|
106 |
+
|
107 |
+
out = F.grid_sample(
|
108 |
+
x, y, mode='bilinear', align_corners=False, padding_mode='zeros'
|
109 |
+
)
|
110 |
+
_, c, _, _ = out.shape
|
111 |
+
out = out.permute(0, 2, 3, 1).view(n, t, p, s, c).permute(0, 2, 1, 3, 4)
|
112 |
+
|
113 |
+
return out
|
114 |
+
|
115 |
+
|
116 |
+
def soft_argmax_heatmap_batched(softmax_val, threshold=5):
|
117 |
+
"""Test if two image resolutions are the same."""
|
118 |
+
b, h, w, d1, d2 = softmax_val.shape
|
119 |
+
y, x = torch.meshgrid(
|
120 |
+
torch.arange(d1, device=softmax_val.device),
|
121 |
+
torch.arange(d2, device=softmax_val.device),
|
122 |
+
indexing='ij',
|
123 |
+
)
|
124 |
+
coords = torch.stack([x + 0.5, y + 0.5], dim=-1).to(softmax_val.device)
|
125 |
+
softmax_val_flat = softmax_val.reshape(b, h, w, -1)
|
126 |
+
argmax_pos = torch.argmax(softmax_val_flat, dim=-1)
|
127 |
+
|
128 |
+
pos = coords.reshape(-1, 2)[argmax_pos]
|
129 |
+
valid = (
|
130 |
+
torch.sum(
|
131 |
+
torch.square(
|
132 |
+
coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :]
|
133 |
+
),
|
134 |
+
dim=-1,
|
135 |
+
keepdims=True,
|
136 |
+
)
|
137 |
+
< threshold**2
|
138 |
+
)
|
139 |
+
|
140 |
+
weighted_sum = torch.sum(
|
141 |
+
coords[None, None, None, :, :, :]
|
142 |
+
* valid
|
143 |
+
* softmax_val[:, :, :, :, :, None],
|
144 |
+
dim=(3, 4),
|
145 |
+
)
|
146 |
+
sum_of_weights = torch.maximum(
|
147 |
+
torch.sum(valid * softmax_val[:, :, :, :, :, None], dim=(3, 4)),
|
148 |
+
torch.tensor(1e-12, device=softmax_val.device),
|
149 |
+
)
|
150 |
+
return weighted_sum / sum_of_weights
|
151 |
+
|
152 |
+
|
153 |
+
def heatmaps_to_points(
|
154 |
+
all_pairs_softmax,
|
155 |
+
image_shape,
|
156 |
+
threshold=5,
|
157 |
+
query_points=None,
|
158 |
+
):
|
159 |
+
"""Convert heatmaps to points using soft argmax."""
|
160 |
+
|
161 |
+
out_points = soft_argmax_heatmap_batched(all_pairs_softmax, threshold)
|
162 |
+
feature_grid_shape = all_pairs_softmax.shape[1:]
|
163 |
+
# Note: out_points is now [x, y]; we need to divide by [width, height].
|
164 |
+
# image_shape[3] is width and image_shape[2] is height.
|
165 |
+
out_points = convert_grid_coordinates(
|
166 |
+
out_points,
|
167 |
+
feature_grid_shape[3:1:-1],
|
168 |
+
image_shape[3:1:-1],
|
169 |
+
)
|
170 |
+
assert feature_grid_shape[1] == image_shape[1]
|
171 |
+
if query_points is not None:
|
172 |
+
# The [..., 0:1] is because we only care about the frame index.
|
173 |
+
query_frame = convert_grid_coordinates(
|
174 |
+
query_points.detach(),
|
175 |
+
image_shape[1:4],
|
176 |
+
feature_grid_shape[1:4],
|
177 |
+
coordinate_format='tyx',
|
178 |
+
)[..., 0:1]
|
179 |
+
|
180 |
+
query_frame = torch.round(query_frame)
|
181 |
+
frame_indices = torch.arange(image_shape[1], device=query_frame.device)[
|
182 |
+
None, None, :
|
183 |
+
]
|
184 |
+
is_query_point = query_frame == frame_indices
|
185 |
+
|
186 |
+
is_query_point = is_query_point[:, :, :, None]
|
187 |
+
out_points = (
|
188 |
+
out_points * ~is_query_point
|
189 |
+
+ torch.flip(query_points[:, :, None], dims=(-1,))[..., 0:2]
|
190 |
+
* is_query_point
|
191 |
+
)
|
192 |
+
|
193 |
+
return out_points
|
194 |
+
|
195 |
+
|
196 |
+
def is_same_res(r1, r2):
|
197 |
+
"""Test if two image resolutions are the same."""
|
198 |
+
return all([x == y for x, y in zip(r1, r2)])
|
199 |
+
|
200 |
+
|
201 |
+
def convert_grid_coordinates(
|
202 |
+
coords: torch.Tensor,
|
203 |
+
input_grid_size: Sequence[int],
|
204 |
+
output_grid_size: Sequence[int],
|
205 |
+
coordinate_format: str = 'xy',
|
206 |
+
) -> torch.Tensor:
|
207 |
+
"""Convert grid coordinates to correct format."""
|
208 |
+
if isinstance(input_grid_size, tuple):
|
209 |
+
input_grid_size = torch.tensor(input_grid_size, device=coords.device)
|
210 |
+
if isinstance(output_grid_size, tuple):
|
211 |
+
output_grid_size = torch.tensor(output_grid_size, device=coords.device)
|
212 |
+
|
213 |
+
if coordinate_format == 'xy':
|
214 |
+
if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
|
215 |
+
raise ValueError(
|
216 |
+
'If coordinate_format is xy, the shapes must be length 2.'
|
217 |
+
)
|
218 |
+
elif coordinate_format == 'tyx':
|
219 |
+
if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
|
220 |
+
raise ValueError(
|
221 |
+
'If coordinate_format is tyx, the shapes must be length 3.'
|
222 |
+
)
|
223 |
+
if input_grid_size[0] != output_grid_size[0]:
|
224 |
+
raise ValueError('converting frame count is not supported.')
|
225 |
+
else:
|
226 |
+
raise ValueError('Recognized coordinate formats are xy and tyx.')
|
227 |
+
|
228 |
+
position_in_grid = coords
|
229 |
+
position_in_grid = position_in_grid * output_grid_size / input_grid_size
|
230 |
+
|
231 |
+
return position_in_grid
|
232 |
+
|
233 |
+
|
234 |
+
class _JaxBackend(backend.Backend[torch.Tensor]):
|
235 |
+
"""Einshape implementation for PyTorch."""
|
236 |
+
|
237 |
+
# https://github.com/vacancy/einshape/blob/main/einshape/src/pytorch/pytorch_ops.py
|
238 |
+
|
239 |
+
def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor:
|
240 |
+
return x.reshape(op.shape)
|
241 |
+
|
242 |
+
def transpose(
|
243 |
+
self, x: torch.Tensor, op: abstract_ops.Transpose
|
244 |
+
) -> torch.Tensor:
|
245 |
+
return x.permute(op.perm)
|
246 |
+
|
247 |
+
def broadcast(
|
248 |
+
self, x: torch.Tensor, op: abstract_ops.Broadcast
|
249 |
+
) -> torch.Tensor:
|
250 |
+
shape = op.transform_shape(x.shape)
|
251 |
+
for axis_position in sorted(op.axis_sizes.keys()):
|
252 |
+
x = x.unsqueeze(axis_position)
|
253 |
+
return x.expand(shape)
|
254 |
+
|
255 |
+
|
256 |
+
def einshape(
|
257 |
+
equation: str, value: Union[torch.Tensor, Any], **index_sizes: int
|
258 |
+
) -> torch.Tensor:
|
259 |
+
"""Reshapes `value` according to the given Shape Equation.
|
260 |
+
|
261 |
+
Args:
|
262 |
+
equation: The Shape Equation specifying the index regrouping and reordering.
|
263 |
+
value: Input tensor, or tensor-like object.
|
264 |
+
**index_sizes: Sizes of indices, where they cannot be inferred from
|
265 |
+
`input_shape`.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
Tensor derived from `value` by reshaping as specified by `equation`.
|
269 |
+
"""
|
270 |
+
if not isinstance(value, torch.Tensor):
|
271 |
+
value = torch.tensor(value)
|
272 |
+
return _JaxBackend().exec(equation, value, value.shape, **index_sizes)
|
273 |
+
|
274 |
+
|
275 |
+
def generate_default_resolutions(full_size, train_size, num_levels=None):
|
276 |
+
"""Generate a list of logarithmically-spaced resolutions.
|
277 |
+
|
278 |
+
Generated resolutions are between train_size and full_size, inclusive, with
|
279 |
+
num_levels different resolutions total. Useful for generating the input to
|
280 |
+
refinement_resolutions in PIPs.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
full_size: 2-tuple of ints. The full image size desired.
|
284 |
+
train_size: 2-tuple of ints. The smallest refinement level. Should
|
285 |
+
typically match the training resolution, which is (256, 256) for TAPIR.
|
286 |
+
num_levels: number of levels. Typically each resolution should be less than
|
287 |
+
twice the size of prior resolutions.
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
A list of resolutions.
|
291 |
+
"""
|
292 |
+
if all([x == y for x, y in zip(train_size, full_size)]):
|
293 |
+
return [train_size]
|
294 |
+
|
295 |
+
if num_levels is None:
|
296 |
+
size_ratio = np.array(full_size) / np.array(train_size)
|
297 |
+
num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1)
|
298 |
+
|
299 |
+
if num_levels <= 1:
|
300 |
+
return [train_size]
|
301 |
+
|
302 |
+
h, w = full_size[0:2]
|
303 |
+
if h % 8 != 0 or w % 8 != 0:
|
304 |
+
print(
|
305 |
+
'Warning: output size is not a multiple of 8. Final layer '
|
306 |
+
+ 'will round size down.'
|
307 |
+
)
|
308 |
+
ll_h, ll_w = train_size[0:2]
|
309 |
+
|
310 |
+
sizes = []
|
311 |
+
for i in range(num_levels):
|
312 |
+
size = (
|
313 |
+
int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8,
|
314 |
+
int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8,
|
315 |
+
)
|
316 |
+
sizes.append(size)
|
317 |
+
return sizes
|
318 |
+
|
319 |
+
|
320 |
+
class Conv2dSamePadding(torch.nn.Conv2d):
|
321 |
+
|
322 |
+
def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
|
323 |
+
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
324 |
+
|
325 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
326 |
+
ih, iw = x.size()[-2:]
|
327 |
+
|
328 |
+
pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
|
329 |
+
pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
|
330 |
+
|
331 |
+
if pad_h > 0 or pad_w > 0:
|
332 |
+
x = F.pad(
|
333 |
+
x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
|
334 |
+
)
|
335 |
+
return F.conv2d(
|
336 |
+
x,
|
337 |
+
self.weight,
|
338 |
+
self.bias,
|
339 |
+
self.stride,
|
340 |
+
# self.padding,
|
341 |
+
0,
|
342 |
+
self.dilation,
|
343 |
+
self.groups,
|
344 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einshape==1.0
|
2 |
+
gradio==4.40.0
|
3 |
+
mediapy==1.2.2
|
4 |
+
opencv-python==4.10.0.84
|
5 |
+
torch==2.4.0
|
6 |
+
torchaudio==2.4.0
|
7 |
+
torchvision==0.19.0
|
viz_utils.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Visualization utility functions."""
|
2 |
+
|
3 |
+
import colorsys
|
4 |
+
import random
|
5 |
+
from typing import List, Optional, Sequence, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
# Generate random colormaps for visualizing different points.
|
11 |
+
def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
|
12 |
+
"""Gets colormap for points."""
|
13 |
+
colors = []
|
14 |
+
for i in np.arange(0.0, 360.0, 360.0 / num_colors):
|
15 |
+
hue = i / 360.0
|
16 |
+
lightness = (50 + np.random.rand() * 10) / 100.0
|
17 |
+
saturation = (90 + np.random.rand() * 10) / 100.0
|
18 |
+
color = colorsys.hls_to_rgb(hue, lightness, saturation)
|
19 |
+
colors.append(
|
20 |
+
(int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
|
21 |
+
)
|
22 |
+
random.shuffle(colors)
|
23 |
+
return colors
|
24 |
+
|
25 |
+
|
26 |
+
def paint_point_track(
|
27 |
+
frames: np.ndarray,
|
28 |
+
point_tracks: np.ndarray,
|
29 |
+
visibles: np.ndarray,
|
30 |
+
colormap: Optional[List[Tuple[int, int, int]]] = None,
|
31 |
+
) -> np.ndarray:
|
32 |
+
"""Converts a sequence of points to color code video.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
frames: [num_frames, height, width, 3], np.uint8, [0, 255]
|
36 |
+
point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
|
37 |
+
visibles: [num_points, num_frames], bool
|
38 |
+
colormap: colormap for points, each point has a different RGB color.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
video: [num_frames, height, width, 3], np.uint8, [0, 255]
|
42 |
+
"""
|
43 |
+
num_points, num_frames = point_tracks.shape[0:2]
|
44 |
+
if colormap is None:
|
45 |
+
colormap = get_colors(num_colors=num_points)
|
46 |
+
height, width = frames.shape[1:3]
|
47 |
+
dot_size_as_fraction_of_min_edge = 0.015
|
48 |
+
radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
|
49 |
+
diam = radius * 2 + 1
|
50 |
+
quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
|
51 |
+
quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
|
52 |
+
icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
|
53 |
+
sharpness = 0.15
|
54 |
+
icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
|
55 |
+
icon = 1 - icon[:, :, np.newaxis]
|
56 |
+
icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
|
57 |
+
icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
|
58 |
+
icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
|
59 |
+
icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
|
60 |
+
|
61 |
+
video = frames.copy()
|
62 |
+
for t in range(num_frames):
|
63 |
+
# Pad so that points that extend outside the image frame don't crash us
|
64 |
+
image = np.pad(
|
65 |
+
video[t],
|
66 |
+
[
|
67 |
+
(radius + 1, radius + 1),
|
68 |
+
(radius + 1, radius + 1),
|
69 |
+
(0, 0),
|
70 |
+
],
|
71 |
+
)
|
72 |
+
for i in range(num_points):
|
73 |
+
# The icon is centered at the center of a pixel, but the input coordinates
|
74 |
+
# are raster coordinates. Therefore, to render a point at (1,1) (which
|
75 |
+
# lies on the corner between four pixels), we need 1/4 of the icon placed
|
76 |
+
# centered on the 0'th row, 0'th column, etc. We need to subtract
|
77 |
+
# 0.5 to make the fractional position come out right.
|
78 |
+
x, y = point_tracks[i, t, :] + 0.5
|
79 |
+
x = min(max(x, 0.0), width)
|
80 |
+
y = min(max(y, 0.0), height)
|
81 |
+
|
82 |
+
if visibles[i, t]:
|
83 |
+
x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
|
84 |
+
x2, y2 = x1 + 1, y1 + 1
|
85 |
+
|
86 |
+
# bilinear interpolation
|
87 |
+
patch = (
|
88 |
+
icon1 * (x2 - x) * (y2 - y)
|
89 |
+
+ icon2 * (x2 - x) * (y - y1)
|
90 |
+
+ icon3 * (x - x1) * (y2 - y)
|
91 |
+
+ icon4 * (x - x1) * (y - y1)
|
92 |
+
)
|
93 |
+
x_ub = x1 + 2 * radius + 2
|
94 |
+
y_ub = y1 + 2 * radius + 2
|
95 |
+
image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
|
96 |
+
y1:y_ub, x1:x_ub, :
|
97 |
+
] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
|
98 |
+
|
99 |
+
# Remove the pad
|
100 |
+
video[t] = image[
|
101 |
+
radius + 1 : -radius - 1, radius + 1 : -radius - 1
|
102 |
+
].astype(np.uint8)
|
103 |
+
return video
|
104 |
+
|