Spaces:
Running
on
Zero
Running
on
Zero
added basics
Browse files- nets/alltracker.py +588 -0
- nets/blocks.py +1304 -0
- utils/basic.py +144 -0
- utils/data.py +96 -0
- utils/improc.py +1103 -0
- utils/loss.py +220 -0
- utils/misc.py +100 -0
- utils/py.py +755 -0
- utils/samp.py +213 -0
- utils/saveload.py +65 -0
nets/alltracker.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import utils.misc
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from nets.blocks import CNBlockConfig, ConvNeXt, conv1x1, RelUpdateBlock, InputPadder, CorrBlock, BasicEncoder
|
| 8 |
+
|
| 9 |
+
class Net(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
seqlen,
|
| 13 |
+
use_attn=True,
|
| 14 |
+
use_mixer=False,
|
| 15 |
+
use_conv=False,
|
| 16 |
+
use_convb=False,
|
| 17 |
+
use_basicencoder=False,
|
| 18 |
+
use_sinmotion=False,
|
| 19 |
+
use_relmotion=False,
|
| 20 |
+
use_sinrelmotion=False,
|
| 21 |
+
use_feats8=False,
|
| 22 |
+
no_time=False,
|
| 23 |
+
no_space=False,
|
| 24 |
+
no_split=False,
|
| 25 |
+
no_ctx=False,
|
| 26 |
+
full_split=False,
|
| 27 |
+
corr_levels=5,
|
| 28 |
+
corr_radius=4,
|
| 29 |
+
num_blocks=3,
|
| 30 |
+
dim=128,
|
| 31 |
+
hdim=128,
|
| 32 |
+
init_weights=True,
|
| 33 |
+
):
|
| 34 |
+
super(Net, self).__init__()
|
| 35 |
+
|
| 36 |
+
self.dim = dim
|
| 37 |
+
self.hdim = hdim
|
| 38 |
+
|
| 39 |
+
self.no_time = no_time
|
| 40 |
+
self.no_space = no_space
|
| 41 |
+
self.seqlen = seqlen
|
| 42 |
+
self.corr_levels = corr_levels
|
| 43 |
+
self.corr_radius = corr_radius
|
| 44 |
+
self.corr_channel = self.corr_levels * (self.corr_radius * 2 + 1) ** 2
|
| 45 |
+
self.num_blocks = num_blocks
|
| 46 |
+
|
| 47 |
+
self.use_feats8 = use_feats8
|
| 48 |
+
self.use_basicencoder = use_basicencoder
|
| 49 |
+
self.use_sinmotion = use_sinmotion
|
| 50 |
+
self.use_relmotion = use_relmotion
|
| 51 |
+
self.use_sinrelmotion = use_sinrelmotion
|
| 52 |
+
self.no_split = no_split
|
| 53 |
+
self.no_ctx = no_ctx
|
| 54 |
+
self.full_split = full_split
|
| 55 |
+
|
| 56 |
+
if use_basicencoder:
|
| 57 |
+
if self.full_split:
|
| 58 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
|
| 59 |
+
self.cnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
|
| 60 |
+
else:
|
| 61 |
+
if self.no_split:
|
| 62 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim, stride=8)
|
| 63 |
+
else:
|
| 64 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.dim*2, stride=8)
|
| 65 |
+
else:
|
| 66 |
+
block_setting = [
|
| 67 |
+
CNBlockConfig(96, 192, 3, True), # 4x
|
| 68 |
+
CNBlockConfig(192, 384, 3, False), # 8x
|
| 69 |
+
CNBlockConfig(384, None, 9, False), # 8x
|
| 70 |
+
]
|
| 71 |
+
self.cnn = ConvNeXt(block_setting, stochastic_depth_prob=0.0, init_weights=init_weights)
|
| 72 |
+
if self.no_split:
|
| 73 |
+
self.dot_conv = conv1x1(384, dim)
|
| 74 |
+
else:
|
| 75 |
+
self.dot_conv = conv1x1(384, dim*2)
|
| 76 |
+
|
| 77 |
+
self.upsample_weight = nn.Sequential(
|
| 78 |
+
# convex combination of 3x3 patches
|
| 79 |
+
nn.Conv2d(dim, dim * 2, 3, padding=1),
|
| 80 |
+
nn.ReLU(inplace=True),
|
| 81 |
+
nn.Conv2d(dim * 2, 64 * 9, 1, padding=0)
|
| 82 |
+
)
|
| 83 |
+
self.flow_head = nn.Sequential(
|
| 84 |
+
nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
|
| 85 |
+
nn.ReLU(inplace=True),
|
| 86 |
+
nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
|
| 87 |
+
)
|
| 88 |
+
self.visconf_head = nn.Sequential(
|
| 89 |
+
nn.Conv2d(dim, 2*dim, kernel_size=3, padding=1),
|
| 90 |
+
nn.ReLU(inplace=True),
|
| 91 |
+
nn.Conv2d(2*dim, 2, kernel_size=3, padding=1)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if self.use_sinrelmotion:
|
| 95 |
+
self.pdim = 84 # 32*2
|
| 96 |
+
elif self.use_relmotion:
|
| 97 |
+
self.pdim = 4
|
| 98 |
+
elif self.use_sinmotion:
|
| 99 |
+
self.pdim = 42
|
| 100 |
+
else:
|
| 101 |
+
self.pdim = 2
|
| 102 |
+
|
| 103 |
+
self.update_block = RelUpdateBlock(self.corr_channel, self.num_blocks, cdim=dim, hdim=hdim, pdim=self.pdim,
|
| 104 |
+
use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb,
|
| 105 |
+
use_layer_scale=True, no_time=no_time, no_space=no_space,
|
| 106 |
+
no_ctx=no_ctx)
|
| 107 |
+
|
| 108 |
+
time_line = torch.linspace(0, seqlen-1, seqlen).reshape(1, seqlen, 1)
|
| 109 |
+
self.register_buffer("time_emb", utils.misc.get_1d_sincos_pos_embed_from_grid(self.dim, time_line[0])) # 1,S,C
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def fetch_time_embed(self, t, dtype, is_training=False):
|
| 113 |
+
S = self.time_emb.shape[1]
|
| 114 |
+
if t == S:
|
| 115 |
+
return self.time_emb.to(dtype)
|
| 116 |
+
elif t==1:
|
| 117 |
+
if is_training:
|
| 118 |
+
ind = np.random.choice(S)
|
| 119 |
+
return self.time_emb[:,ind:ind+1].to(dtype)
|
| 120 |
+
else:
|
| 121 |
+
return self.time_emb[:,1:2].to(dtype)
|
| 122 |
+
else:
|
| 123 |
+
time_emb = self.time_emb.float()
|
| 124 |
+
time_emb = F.interpolate(time_emb.permute(0, 2, 1), size=t, mode="linear").permute(0, 2, 1)
|
| 125 |
+
return time_emb.to(dtype)
|
| 126 |
+
|
| 127 |
+
def coords_grid(self, batch, ht, wd, device, dtype):
|
| 128 |
+
coords = torch.meshgrid(torch.arange(ht, device=device, dtype=dtype), torch.arange(wd, device=device, dtype=dtype), indexing='ij')
|
| 129 |
+
coords = torch.stack(coords[::-1], dim=0)
|
| 130 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
| 131 |
+
|
| 132 |
+
def initialize_flow(self, img):
|
| 133 |
+
""" Flow is represented as difference between two coordinate grids flow = coords2 - coords1"""
|
| 134 |
+
N, C, H, W = img.shape
|
| 135 |
+
coords1 = self.coords_grid(N, H//8, W//8, device=img.device)
|
| 136 |
+
coords2 = self.coords_grid(N, H//8, W//8, device=img.device)
|
| 137 |
+
return coords1, coords2
|
| 138 |
+
|
| 139 |
+
def upsample_data(self, flow, mask):
|
| 140 |
+
""" Upsample [H/8, W/8, C] -> [H, W, C] using convex combination """
|
| 141 |
+
N, C, H, W = flow.shape
|
| 142 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
| 143 |
+
mask = torch.softmax(mask, dim=2)
|
| 144 |
+
|
| 145 |
+
up_flow = F.unfold(8 * flow, [3,3], padding=1)
|
| 146 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
| 147 |
+
|
| 148 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
| 149 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
| 150 |
+
|
| 151 |
+
return up_flow.reshape(N, 2, 8*H, 8*W).to(flow.dtype)
|
| 152 |
+
|
| 153 |
+
def get_T_padded_images(self, images, T, S, is_training, stride=None, pad=True):
|
| 154 |
+
B,T,C,H,W = images.shape
|
| 155 |
+
indices = None
|
| 156 |
+
if T > 2:
|
| 157 |
+
step = S // 2 if stride is None else stride
|
| 158 |
+
indices = []
|
| 159 |
+
start = 0
|
| 160 |
+
while start + S < T:
|
| 161 |
+
indices.append(start)
|
| 162 |
+
start += step
|
| 163 |
+
indices.append(start)
|
| 164 |
+
Tpad = indices[-1]+S-T
|
| 165 |
+
if pad:
|
| 166 |
+
if is_training:
|
| 167 |
+
assert Tpad == 0
|
| 168 |
+
else:
|
| 169 |
+
images = images.reshape(B,1,T,C*H*W)
|
| 170 |
+
if Tpad > 0:
|
| 171 |
+
padding_tensor = images[:,:,-1:,:].expand(B,1,Tpad,C*H*W)
|
| 172 |
+
images = torch.cat([images, padding_tensor], dim=2)
|
| 173 |
+
images = images.reshape(B,T+Tpad,C,H,W)
|
| 174 |
+
T = T+Tpad
|
| 175 |
+
else:
|
| 176 |
+
assert T == 2
|
| 177 |
+
return images, T, indices
|
| 178 |
+
|
| 179 |
+
def get_fmaps(self, images_, B, T, sw, is_training):
|
| 180 |
+
_, _, H_pad, W_pad = images_.shape # revised HW
|
| 181 |
+
|
| 182 |
+
C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
|
| 183 |
+
if self.no_split:
|
| 184 |
+
C = self.dim
|
| 185 |
+
|
| 186 |
+
fmaps_chunk_size = 32
|
| 187 |
+
if (not is_training) and (T > fmaps_chunk_size):
|
| 188 |
+
images = images_.reshape(B,T,3,H_pad,W_pad)
|
| 189 |
+
fmaps = []
|
| 190 |
+
for t in range(0, T, fmaps_chunk_size):
|
| 191 |
+
images_chunk = images[:, t : t + fmaps_chunk_size]
|
| 192 |
+
images_chunk = images_chunk.cuda()
|
| 193 |
+
if self.use_basicencoder:
|
| 194 |
+
if self.full_split:
|
| 195 |
+
fmaps_chunk1 = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
| 196 |
+
fmaps_chunk2 = self.cnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
| 197 |
+
fmaps_chunk = torch.cat([fmaps_chunk1, fmaps_chunk2], axis=1)
|
| 198 |
+
else:
|
| 199 |
+
fmaps_chunk = self.fnet(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
| 200 |
+
else:
|
| 201 |
+
fmaps_chunk = self.cnn(images_chunk.reshape(-1, 3, H_pad, W_pad))
|
| 202 |
+
if t==0 and sw is not None and sw.save_this:
|
| 203 |
+
sw.summ_feat('1_model/fmap_raw', fmaps_chunk[0:1])
|
| 204 |
+
fmaps_chunk = self.dot_conv(fmaps_chunk) # B*T,C,H8,W8
|
| 205 |
+
T_chunk = images_chunk.shape[1]
|
| 206 |
+
fmaps.append(fmaps_chunk.reshape(B, -1, C, H8, W8))
|
| 207 |
+
fmaps_ = torch.cat(fmaps, dim=1).reshape(-1, C, H8, W8)
|
| 208 |
+
else:
|
| 209 |
+
if not is_training:
|
| 210 |
+
# sometimes we need to move things to cuda here
|
| 211 |
+
images_ = images_.cuda()
|
| 212 |
+
if self.use_basicencoder:
|
| 213 |
+
if self.full_split:
|
| 214 |
+
fmaps1_ = self.fnet(images_)
|
| 215 |
+
fmaps2_ = self.cnet(images_)
|
| 216 |
+
fmaps_ = torch.cat([fmaps1_, fmaps2_], axis=1)
|
| 217 |
+
else:
|
| 218 |
+
fmaps_ = self.fnet(images_)
|
| 219 |
+
else:
|
| 220 |
+
fmaps_ = self.cnn(images_)
|
| 221 |
+
if sw is not None and sw.save_this:
|
| 222 |
+
sw.summ_feat('1_model/fmap_raw', fmaps_[0:1])
|
| 223 |
+
fmaps_ = self.dot_conv(fmaps_) # B*T,C,H8,W8
|
| 224 |
+
return fmaps_
|
| 225 |
+
|
| 226 |
+
def forward(self, images, iters=4, sw=None, is_training=False, stride=None):
|
| 227 |
+
B,T,C,H,W = images.shape
|
| 228 |
+
S = self.seqlen
|
| 229 |
+
device = images.device
|
| 230 |
+
dtype = images.dtype
|
| 231 |
+
|
| 232 |
+
print('images', images.shape)
|
| 233 |
+
|
| 234 |
+
# images are in [0,255]
|
| 235 |
+
mean = torch.as_tensor([0.485, 0.456, 0.406], device=device).reshape(1,1,3,1,1).to(images.dtype)
|
| 236 |
+
std = torch.as_tensor([0.229, 0.224, 0.225], device=device).reshape(1,1,3,1,1).to(images.dtype)
|
| 237 |
+
images = images / 255.0
|
| 238 |
+
images = (images - mean)/std
|
| 239 |
+
print("a0 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 240 |
+
|
| 241 |
+
T_bak = T
|
| 242 |
+
if stride is not None:
|
| 243 |
+
pad = False
|
| 244 |
+
else:
|
| 245 |
+
pad = True
|
| 246 |
+
images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride=stride, pad=pad)
|
| 247 |
+
|
| 248 |
+
images = images.contiguous()
|
| 249 |
+
images_ = images.reshape(B*T,3,H,W)
|
| 250 |
+
padder = InputPadder(images_.shape)
|
| 251 |
+
images_ = padder.pad(images_)[0]
|
| 252 |
+
|
| 253 |
+
print("a1 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 254 |
+
|
| 255 |
+
_, _, H_pad, W_pad = images_.shape # revised HW
|
| 256 |
+
C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
|
| 257 |
+
C2 = C//2
|
| 258 |
+
if self.no_split:
|
| 259 |
+
C = self.dim
|
| 260 |
+
C2 = C
|
| 261 |
+
|
| 262 |
+
fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
|
| 263 |
+
device = fmaps.device
|
| 264 |
+
print("a2 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 265 |
+
|
| 266 |
+
fmap_anchor = fmaps[:,0]
|
| 267 |
+
|
| 268 |
+
if T<=2 or is_training:
|
| 269 |
+
# note: collecting preds can get expensive on a long video
|
| 270 |
+
all_flow_preds = []
|
| 271 |
+
all_visconf_preds = []
|
| 272 |
+
else:
|
| 273 |
+
all_flow_preds = None
|
| 274 |
+
all_visconf_preds = None
|
| 275 |
+
|
| 276 |
+
if T > 2: # multiframe tracking
|
| 277 |
+
|
| 278 |
+
# we will store our final outputs in these tensors
|
| 279 |
+
full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
|
| 280 |
+
full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device=device)
|
| 281 |
+
# 1/8 resolution
|
| 282 |
+
full_flows8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 283 |
+
full_visconfs8 = torch.zeros((B,T,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 284 |
+
|
| 285 |
+
if self.use_feats8:
|
| 286 |
+
full_feats8 = torch.zeros((B,T,C2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 287 |
+
visits = np.zeros((T))
|
| 288 |
+
print("a3 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 289 |
+
|
| 290 |
+
for ii, ind in enumerate(indices):
|
| 291 |
+
ara = np.arange(ind,ind+S)
|
| 292 |
+
print('ara', ara)
|
| 293 |
+
if ii < len(indices)-1:
|
| 294 |
+
next_ind = indices[ii+1]
|
| 295 |
+
next_ara = np.arange(next_ind,next_ind+S)
|
| 296 |
+
|
| 297 |
+
# print("torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024), 'ara', ara)
|
| 298 |
+
fmaps2 = fmaps[:,ara]
|
| 299 |
+
flows8 = full_flows8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
|
| 300 |
+
visconfs8 = full_visconfs8[:,ara].reshape(B*(S),2,H_pad//8,W_pad//8).detach()
|
| 301 |
+
|
| 302 |
+
if self.use_feats8:
|
| 303 |
+
if ind==0:
|
| 304 |
+
feats8 = None
|
| 305 |
+
else:
|
| 306 |
+
feats8 = full_feats8[:,ara].reshape(B*(S),C2,H_pad//8,W_pad//8).detach()
|
| 307 |
+
else:
|
| 308 |
+
feats8 = None
|
| 309 |
+
print("a4 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 310 |
+
|
| 311 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
|
| 312 |
+
fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=feats8, flows8=flows8,
|
| 313 |
+
is_training=is_training)
|
| 314 |
+
print("a5 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 315 |
+
|
| 316 |
+
unpad_flow_predictions = []
|
| 317 |
+
unpad_visconf_predictions = []
|
| 318 |
+
for i in range(len(flow_predictions)):
|
| 319 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
| 320 |
+
unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
|
| 321 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
| 322 |
+
unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
|
| 323 |
+
print("a6 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 324 |
+
|
| 325 |
+
full_flows[:,ara] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)
|
| 326 |
+
full_flows8[:,ara] = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
|
| 327 |
+
full_visconfs[:,ara] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)
|
| 328 |
+
full_visconfs8[:,ara] = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
|
| 329 |
+
if self.use_feats8:
|
| 330 |
+
full_feats8[:,ara] = feats8.reshape(B,S,C2,H_pad//8,W_pad//8)
|
| 331 |
+
visits[ara] += 1
|
| 332 |
+
print("a7 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 333 |
+
|
| 334 |
+
if is_training:
|
| 335 |
+
all_flow_preds.append(unpad_flow_predictions)
|
| 336 |
+
all_visconf_preds.append(unpad_visconf_predictions)
|
| 337 |
+
else:
|
| 338 |
+
del unpad_flow_predictions
|
| 339 |
+
del unpad_visconf_predictions
|
| 340 |
+
|
| 341 |
+
# for the next iter, replace empty data with nearest available preds
|
| 342 |
+
invalid_idx = np.where(visits==0)[0]
|
| 343 |
+
valid_idx = np.where(visits>0)[0]
|
| 344 |
+
for idx in invalid_idx:
|
| 345 |
+
nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
|
| 346 |
+
# print('replacing %d with %d' % (idx, nearest))
|
| 347 |
+
full_flows8[:,idx] = full_flows8[:,nearest]
|
| 348 |
+
full_visconfs8[:,idx] = full_visconfs8[:,nearest]
|
| 349 |
+
if self.use_feats8:
|
| 350 |
+
full_feats8[:,idx] = full_feats8[:,nearest]
|
| 351 |
+
print("a8 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 352 |
+
else: # flow
|
| 353 |
+
|
| 354 |
+
flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 355 |
+
visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 356 |
+
|
| 357 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
|
| 358 |
+
fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
|
| 359 |
+
is_training=is_training)
|
| 360 |
+
unpad_flow_predictions = []
|
| 361 |
+
unpad_visconf_predictions = []
|
| 362 |
+
for i in range(len(flow_predictions)):
|
| 363 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
| 364 |
+
all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
|
| 365 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
| 366 |
+
all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
|
| 367 |
+
full_flows = all_flow_preds[-1].reshape(B,2,H,W)
|
| 368 |
+
full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W)
|
| 369 |
+
|
| 370 |
+
if (not is_training) and (T > 2):
|
| 371 |
+
full_flows = full_flows[:,:T_bak]
|
| 372 |
+
full_visconfs = full_visconfs[:,:T_bak]
|
| 373 |
+
print("a9 torch.cuda.memory_allocated: %.1fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
|
| 374 |
+
|
| 375 |
+
return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
|
| 376 |
+
|
| 377 |
+
def forward_sliding(self, images, iters=4, sw=None, is_training=False, window_len=None, stride=None):
|
| 378 |
+
B,T,C,H,W = images.shape
|
| 379 |
+
S = self.seqlen if window_len is None else window_len
|
| 380 |
+
device = images.device
|
| 381 |
+
dtype = images.dtype
|
| 382 |
+
stride = S // 2 if stride is None else stride
|
| 383 |
+
|
| 384 |
+
T_bak = T
|
| 385 |
+
images, T, indices = self.get_T_padded_images(images, T, S, is_training, stride)
|
| 386 |
+
assert stride <= S // 2
|
| 387 |
+
|
| 388 |
+
images = images.contiguous()
|
| 389 |
+
images_ = images.reshape(B*T,3,H,W)
|
| 390 |
+
padder = InputPadder(images_.shape)
|
| 391 |
+
images_ = padder.pad(images_)[0]
|
| 392 |
+
|
| 393 |
+
_, _, H_pad, W_pad = images_.shape # revised HW
|
| 394 |
+
C, H8, W8 = self.dim*2, H_pad//8, W_pad//8
|
| 395 |
+
C2 = C//2
|
| 396 |
+
if self.no_split:
|
| 397 |
+
C = self.dim
|
| 398 |
+
C2 = C
|
| 399 |
+
|
| 400 |
+
all_flow_preds = None
|
| 401 |
+
all_visconf_preds = None
|
| 402 |
+
|
| 403 |
+
if T<=2:
|
| 404 |
+
# note: collecting preds can get expensive on a long video
|
| 405 |
+
all_flow_preds = []
|
| 406 |
+
all_visconf_preds = []
|
| 407 |
+
|
| 408 |
+
fmaps = self.get_fmaps(images_, B, T, sw, is_training).reshape(B,T,C,H8,W8)
|
| 409 |
+
device = fmaps.device
|
| 410 |
+
|
| 411 |
+
flows8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 412 |
+
visconfs8 = torch.zeros((B,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 413 |
+
|
| 414 |
+
fmap_anchor = fmaps[:,0]
|
| 415 |
+
|
| 416 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, feats8 = self.forward_window(
|
| 417 |
+
fmap_anchor, fmaps[:,1:2], visconfs8, iters=iters, flowfeat=None, flows8=flows8,
|
| 418 |
+
is_training=is_training)
|
| 419 |
+
unpad_flow_predictions = []
|
| 420 |
+
unpad_visconf_predictions = []
|
| 421 |
+
for i in range(len(flow_predictions)):
|
| 422 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
| 423 |
+
all_flow_preds.append(flow_predictions[i].reshape(B,2,H,W))
|
| 424 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
| 425 |
+
all_visconf_preds.append(visconf_predictions[i].reshape(B,2,H,W))
|
| 426 |
+
full_flows = all_flow_preds[-1].reshape(B,2,H,W).detach().cpu()
|
| 427 |
+
full_visconfs = all_visconf_preds[-1].reshape(B,2,H,W).detach().cpu()
|
| 428 |
+
|
| 429 |
+
return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
|
| 430 |
+
|
| 431 |
+
assert T > 2 # multiframe tracking
|
| 432 |
+
|
| 433 |
+
if is_training:
|
| 434 |
+
all_flow_preds = []
|
| 435 |
+
all_visconf_preds = []
|
| 436 |
+
|
| 437 |
+
# we will store our final outputs in these cpu tensors
|
| 438 |
+
full_flows = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
|
| 439 |
+
full_visconfs = torch.zeros((B,T,2,H,W), dtype=dtype, device='cpu')
|
| 440 |
+
|
| 441 |
+
images_ = images_.reshape(B,T,3,H_pad,W_pad)
|
| 442 |
+
fmap_anchor = self.get_fmaps(images_[:,:1].reshape(-1,3,H_pad,W_pad), B, 1, sw, is_training).reshape(B,C,H8,W8)
|
| 443 |
+
device = fmap_anchor.device
|
| 444 |
+
full_visited = torch.zeros((T,), dtype=torch.bool, device=device)
|
| 445 |
+
|
| 446 |
+
for ii, ind in enumerate(indices):
|
| 447 |
+
ara = np.arange(ind,ind+S)
|
| 448 |
+
if ii == 0:
|
| 449 |
+
flows8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 450 |
+
visconfs8 = torch.zeros((B,S,2,H_pad//8,W_pad//8), dtype=dtype, device=device)
|
| 451 |
+
fmaps2 = self.get_fmaps(images_[:,ara].reshape(-1,3,H_pad,W_pad), B, S, sw, is_training).reshape(B,S,C,H8,W8)
|
| 452 |
+
else:
|
| 453 |
+
flows8 = torch.cat([flows8[:,stride:stride+S//2], flows8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
|
| 454 |
+
visconfs8 = torch.cat([visconfs8[:,stride:stride+S//2], visconfs8[:,stride+S//2-1:stride+S//2].repeat(1,S//2,1,1,1)], dim=1)
|
| 455 |
+
fmaps2 = torch.cat([fmaps2[:,stride:stride+S//2],
|
| 456 |
+
self.get_fmaps(images_[:,np.arange(ind+S//2,ind+S)].reshape(-1,3,H_pad,W_pad), B, S//2, sw, is_training).reshape(B,S//2,C,H8,W8)], dim=1)
|
| 457 |
+
|
| 458 |
+
flows8 = flows8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
|
| 459 |
+
visconfs8 = visconfs8.reshape(B*S,2,H_pad//8,W_pad//8).detach()
|
| 460 |
+
|
| 461 |
+
flow_predictions, visconf_predictions, flows8, visconfs8, _ = self.forward_window(
|
| 462 |
+
fmap_anchor, fmaps2, visconfs8, iters=iters, flowfeat=None, flows8=flows8,
|
| 463 |
+
is_training=is_training)
|
| 464 |
+
|
| 465 |
+
unpad_flow_predictions = []
|
| 466 |
+
unpad_visconf_predictions = []
|
| 467 |
+
for i in range(len(flow_predictions)):
|
| 468 |
+
flow_predictions[i] = padder.unpad(flow_predictions[i])
|
| 469 |
+
unpad_flow_predictions.append(flow_predictions[i].reshape(B,S,2,H,W))
|
| 470 |
+
visconf_predictions[i] = padder.unpad(torch.sigmoid(visconf_predictions[i]))
|
| 471 |
+
unpad_visconf_predictions.append(visconf_predictions[i].reshape(B,S,2,H,W))
|
| 472 |
+
|
| 473 |
+
current_visiting = torch.zeros((T,), dtype=torch.bool, device=device)
|
| 474 |
+
current_visiting[ara] = True
|
| 475 |
+
|
| 476 |
+
to_fill = current_visiting & (~full_visited)
|
| 477 |
+
to_fill_sum = to_fill.sum().item()
|
| 478 |
+
full_flows[:,to_fill] = unpad_flow_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
|
| 479 |
+
full_visconfs[:,to_fill] = unpad_visconf_predictions[-1].reshape(B,S,2,H,W)[:,-to_fill_sum:].detach().cpu()
|
| 480 |
+
full_visited |= current_visiting
|
| 481 |
+
|
| 482 |
+
if is_training:
|
| 483 |
+
all_flow_preds.append(unpad_flow_predictions)
|
| 484 |
+
all_visconf_preds.append(unpad_visconf_predictions)
|
| 485 |
+
else:
|
| 486 |
+
del unpad_flow_predictions
|
| 487 |
+
del unpad_visconf_predictions
|
| 488 |
+
|
| 489 |
+
flows8 = flows8.reshape(B,S,2,H_pad//8,W_pad//8)
|
| 490 |
+
visconfs8 = visconfs8.reshape(B,S,2,H_pad//8,W_pad//8)
|
| 491 |
+
|
| 492 |
+
if not is_training:
|
| 493 |
+
full_flows = full_flows[:,:T_bak]
|
| 494 |
+
full_visconfs = full_visconfs[:,:T_bak]
|
| 495 |
+
|
| 496 |
+
return full_flows, full_visconfs, all_flow_preds, all_visconf_preds
|
| 497 |
+
|
| 498 |
+
def forward_window(self, fmap1_single, fmaps2, visconfs8, iters=None, flowfeat=None, flows8=None, sw=None, is_training=False):
|
| 499 |
+
B,S,C,H8,W8 = fmaps2.shape
|
| 500 |
+
device = fmaps2.device
|
| 501 |
+
dtype = fmaps2.dtype
|
| 502 |
+
|
| 503 |
+
flow_predictions = []
|
| 504 |
+
visconf_predictions = []
|
| 505 |
+
|
| 506 |
+
fmap1 = fmap1_single.unsqueeze(1).repeat(1,S,1,1,1) # B,S,C,H,W
|
| 507 |
+
fmap1 = fmap1.reshape(B*(S),C,H8,W8).contiguous()
|
| 508 |
+
|
| 509 |
+
fmap2 = fmaps2.reshape(B*(S),C,H8,W8).contiguous()
|
| 510 |
+
|
| 511 |
+
visconfs8 = visconfs8.reshape(B*(S),2,H8,W8).contiguous()
|
| 512 |
+
|
| 513 |
+
corr_fn = CorrBlock(fmap1, fmap2, self.corr_levels, self.corr_radius)
|
| 514 |
+
|
| 515 |
+
coords1 = self.coords_grid(B*(S), H8, W8, device=fmap1.device, dtype=dtype)
|
| 516 |
+
|
| 517 |
+
if self.no_split:
|
| 518 |
+
flowfeat, ctxfeat = fmap1.clone(), fmap1.clone()
|
| 519 |
+
else:
|
| 520 |
+
if flowfeat is not None:
|
| 521 |
+
_, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
|
| 522 |
+
else:
|
| 523 |
+
flowfeat, ctxfeat = torch.split(fmap1, [self.dim, self.dim], dim=1)
|
| 524 |
+
|
| 525 |
+
# add pos emb to ctxfeat (and not flowfeat), since ctxfeat is untouched across iters
|
| 526 |
+
time_emb = self.fetch_time_embed(S, ctxfeat.dtype, is_training).reshape(1,S,self.dim,1,1).repeat(B,1,1,1,1)
|
| 527 |
+
ctxfeat = ctxfeat + time_emb.reshape(B*S,self.dim,1,1)
|
| 528 |
+
|
| 529 |
+
if self.no_ctx:
|
| 530 |
+
flowfeat = flowfeat + time_emb.reshape(B*S,self.dim,1,1)
|
| 531 |
+
|
| 532 |
+
for itr in range(iters):
|
| 533 |
+
_, _, H8, W8 = flows8.shape
|
| 534 |
+
flows8 = flows8.detach()
|
| 535 |
+
coords2 = (coords1 + flows8).detach() # B*S,2,H,W
|
| 536 |
+
corr = corr_fn(coords2).to(dtype)
|
| 537 |
+
|
| 538 |
+
if self.use_relmotion or self.use_sinrelmotion:
|
| 539 |
+
coords_ = coords2.reshape(B,S,2,H8*W8).permute(0,1,3,2) # B,S,H8*W8,2
|
| 540 |
+
rel_coords_forward = coords_[:, :-1] - coords_[:, 1:]
|
| 541 |
+
rel_coords_backward = coords_[:, 1:] - coords_[:, :-1]
|
| 542 |
+
rel_coords_forward = torch.nn.functional.pad(
|
| 543 |
+
rel_coords_forward, (0, 0, 0, 0, 0, 1) # pad the 3rd-last dim (S) by (0,1)
|
| 544 |
+
)
|
| 545 |
+
rel_coords_backward = torch.nn.functional.pad(
|
| 546 |
+
rel_coords_backward, (0, 0, 0, 0, 1, 0) # pad the 3rd-last dim (S) by (1,0)
|
| 547 |
+
)
|
| 548 |
+
rel_coords = torch.cat([rel_coords_forward, rel_coords_backward], dim=-1) # B,S,H8*W8,4
|
| 549 |
+
|
| 550 |
+
if self.use_sinrelmotion:
|
| 551 |
+
rel_pos_emb_input = utils.misc.posenc(
|
| 552 |
+
rel_coords,
|
| 553 |
+
min_deg=0,
|
| 554 |
+
max_deg=10,
|
| 555 |
+
) # B,S,H*W,pdim
|
| 556 |
+
motion = rel_pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
|
| 557 |
+
else:
|
| 558 |
+
motion = rel_coords.reshape(B*S,H8,W8,4).permute(0,3,1,2).to(dtype) # B*S,4,H8,W8
|
| 559 |
+
|
| 560 |
+
else:
|
| 561 |
+
if self.use_sinmotion:
|
| 562 |
+
pos_emb_input = utils.misc.posenc(
|
| 563 |
+
flows8.reshape(B,S,H8*W8,2),
|
| 564 |
+
min_deg=0,
|
| 565 |
+
max_deg=10,
|
| 566 |
+
) # B,S,H*W,pdim
|
| 567 |
+
motion = pos_emb_input.reshape(B*S,H8,W8,self.pdim).permute(0,3,1,2).to(dtype) # B*S,pdim,H8,W8
|
| 568 |
+
else:
|
| 569 |
+
motion = flows8
|
| 570 |
+
|
| 571 |
+
flowfeat = self.update_block(flowfeat, ctxfeat, visconfs8, corr, motion, S)
|
| 572 |
+
flow_update = self.flow_head(flowfeat)
|
| 573 |
+
visconf_update = self.visconf_head(flowfeat)
|
| 574 |
+
weight_update = .25 * self.upsample_weight(flowfeat)
|
| 575 |
+
flows8 = flows8 + flow_update
|
| 576 |
+
visconfs8 = visconfs8 + visconf_update
|
| 577 |
+
flow_up = self.upsample_data(flows8, weight_update)
|
| 578 |
+
visconf_up = self.upsample_data(visconfs8, weight_update)
|
| 579 |
+
if not is_training: # clear mem
|
| 580 |
+
flow_predictions = []
|
| 581 |
+
visconf_predictions = []
|
| 582 |
+
flow_predictions.append(flow_up)
|
| 583 |
+
visconf_predictions.append(visconf_up)
|
| 584 |
+
|
| 585 |
+
return flow_predictions, visconf_predictions, flows8, visconfs8, flowfeat
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
|
nets/blocks.py
ADDED
|
@@ -0,0 +1,1304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch import nn, Tensor
|
| 5 |
+
from itertools import repeat
|
| 6 |
+
import collections
|
| 7 |
+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
|
| 8 |
+
from functools import partial
|
| 9 |
+
import einops
|
| 10 |
+
import math
|
| 11 |
+
from torchvision.ops.misc import Conv2dNormActivation, Permute
|
| 12 |
+
from torchvision.ops.stochastic_depth import StochasticDepth
|
| 13 |
+
|
| 14 |
+
def _ntuple(n):
|
| 15 |
+
def parse(x):
|
| 16 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 17 |
+
return tuple(x)
|
| 18 |
+
return tuple(repeat(x, n))
|
| 19 |
+
return parse
|
| 20 |
+
|
| 21 |
+
def exists(val):
|
| 22 |
+
return val is not None
|
| 23 |
+
|
| 24 |
+
def default(val, d):
|
| 25 |
+
return val if exists(val) else d
|
| 26 |
+
|
| 27 |
+
to_2tuple = _ntuple(2)
|
| 28 |
+
|
| 29 |
+
class InputPadder:
|
| 30 |
+
""" Pads images such that dimensions are divisible by a certain stride """
|
| 31 |
+
def __init__(self, dims, mode='sintel'):
|
| 32 |
+
self.ht, self.wd = dims[-2:]
|
| 33 |
+
pad_ht = (((self.ht // 64) + 1) * 64 - self.ht) % 64
|
| 34 |
+
pad_wd = (((self.wd // 64) + 1) * 64 - self.wd) % 64
|
| 35 |
+
if mode == 'sintel':
|
| 36 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
| 37 |
+
else:
|
| 38 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
| 39 |
+
|
| 40 |
+
def pad(self, *inputs):
|
| 41 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
| 42 |
+
|
| 43 |
+
def unpad(self, x):
|
| 44 |
+
ht, wd = x.shape[-2:]
|
| 45 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
| 46 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
| 47 |
+
|
| 48 |
+
def bilinear_sampler(
|
| 49 |
+
input, coords,
|
| 50 |
+
align_corners=True,
|
| 51 |
+
padding_mode="border",
|
| 52 |
+
normalize_coords=True):
|
| 53 |
+
# func from mattie (oct9)
|
| 54 |
+
if input.ndim not in [4, 5]:
|
| 55 |
+
raise ValueError("input must be 4D or 5D.")
|
| 56 |
+
|
| 57 |
+
if input.ndim == 4 and not coords.ndim == 4:
|
| 58 |
+
raise ValueError("input is 4D, but coords is not 4D.")
|
| 59 |
+
|
| 60 |
+
if input.ndim == 5 and not coords.ndim == 5:
|
| 61 |
+
raise ValueError("input is 5D, but coords is not 5D.")
|
| 62 |
+
|
| 63 |
+
if coords.ndim == 5:
|
| 64 |
+
coords = coords[..., [1, 2, 0]] # t x y -> x y t to match what grid_sample() expects.
|
| 65 |
+
|
| 66 |
+
if normalize_coords:
|
| 67 |
+
if align_corners:
|
| 68 |
+
# Normalize coordinates from [0, W/H - 1] to [-1, 1].
|
| 69 |
+
coords = (
|
| 70 |
+
coords
|
| 71 |
+
* torch.tensor([2 / max(size - 1, 1) for size in reversed(input.shape[2:])], device=coords.device)
|
| 72 |
+
- 1
|
| 73 |
+
)
|
| 74 |
+
else:
|
| 75 |
+
# Normalize coordinates from [0, W/H] to [-1, 1].
|
| 76 |
+
coords = coords * torch.tensor([2 / size for size in reversed(input.shape[2:])], device=coords.device) - 1
|
| 77 |
+
|
| 78 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CorrBlock:
|
| 82 |
+
def __init__(self, fmap1, fmap2, corr_levels, corr_radius):
|
| 83 |
+
self.num_levels = corr_levels
|
| 84 |
+
self.radius = corr_radius
|
| 85 |
+
self.corr_pyramid = []
|
| 86 |
+
# all pairs correlation
|
| 87 |
+
for i in range(self.num_levels):
|
| 88 |
+
corr = CorrBlock.corr(fmap1, fmap2, 1)
|
| 89 |
+
batch, h1, w1, dim, h2, w2 = corr.shape
|
| 90 |
+
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
|
| 91 |
+
fmap2 = F.interpolate(fmap2, scale_factor=0.5, mode='area')
|
| 92 |
+
# print('corr', corr.shape)
|
| 93 |
+
self.corr_pyramid.append(corr)
|
| 94 |
+
|
| 95 |
+
def __call__(self, coords, dilation=None):
|
| 96 |
+
r = self.radius
|
| 97 |
+
coords = coords.permute(0, 2, 3, 1)
|
| 98 |
+
batch, h1, w1, _ = coords.shape
|
| 99 |
+
|
| 100 |
+
if dilation is None:
|
| 101 |
+
dilation = torch.ones(batch, 1, h1, w1, device=coords.device)
|
| 102 |
+
|
| 103 |
+
out_pyramid = []
|
| 104 |
+
for i in range(self.num_levels):
|
| 105 |
+
corr = self.corr_pyramid[i]
|
| 106 |
+
device = coords.device
|
| 107 |
+
dx = torch.linspace(-r, r, 2*r+1, device=device)
|
| 108 |
+
dy = torch.linspace(-r, r, 2*r+1, device=device)
|
| 109 |
+
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
|
| 110 |
+
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
| 111 |
+
delta_lvl = delta_lvl * dilation.view(batch * h1 * w1, 1, 1, 1)
|
| 112 |
+
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
| 113 |
+
coords_lvl = centroid_lvl + delta_lvl
|
| 114 |
+
corr = bilinear_sampler(corr, coords_lvl)
|
| 115 |
+
corr = corr.view(batch, h1, w1, -1)
|
| 116 |
+
out_pyramid.append(corr)
|
| 117 |
+
|
| 118 |
+
out = torch.cat(out_pyramid, dim=-1)
|
| 119 |
+
out = out.permute(0, 3, 1, 2).contiguous().float()
|
| 120 |
+
return out
|
| 121 |
+
|
| 122 |
+
@staticmethod
|
| 123 |
+
def corr(fmap1, fmap2, num_head):
|
| 124 |
+
batch, dim, h1, w1 = fmap1.shape
|
| 125 |
+
h2, w2 = fmap2.shape[2:]
|
| 126 |
+
fmap1 = fmap1.view(batch, num_head, dim // num_head, h1*w1)
|
| 127 |
+
fmap2 = fmap2.view(batch, num_head, dim // num_head, h2*w2)
|
| 128 |
+
corr = fmap1.transpose(2, 3) @ fmap2
|
| 129 |
+
corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
|
| 130 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|
| 131 |
+
|
| 132 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 133 |
+
"""1x1 convolution without padding"""
|
| 134 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
|
| 135 |
+
|
| 136 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
| 137 |
+
"""3x3 convolution with padding"""
|
| 138 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
|
| 139 |
+
|
| 140 |
+
class LayerNorm2d(nn.LayerNorm):
|
| 141 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 142 |
+
x = x.permute(0, 2, 3, 1)
|
| 143 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 144 |
+
x = x.permute(0, 3, 1, 2)
|
| 145 |
+
return x
|
| 146 |
+
|
| 147 |
+
class CNBlock1d(nn.Module):
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
dim,
|
| 151 |
+
output_dim,
|
| 152 |
+
layer_scale: float = 1e-6,
|
| 153 |
+
stochastic_depth_prob: float = 0,
|
| 154 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
| 155 |
+
dense=True,
|
| 156 |
+
use_attn=True,
|
| 157 |
+
use_mixer=False,
|
| 158 |
+
use_conv=False,
|
| 159 |
+
use_convb=False,
|
| 160 |
+
use_layer_scale=True,
|
| 161 |
+
) -> None:
|
| 162 |
+
super().__init__()
|
| 163 |
+
self.dense = dense
|
| 164 |
+
self.use_attn = use_attn
|
| 165 |
+
self.use_mixer = use_mixer
|
| 166 |
+
self.use_conv = use_conv
|
| 167 |
+
self.use_layer_scale = use_layer_scale
|
| 168 |
+
|
| 169 |
+
if use_attn:
|
| 170 |
+
assert not use_mixer
|
| 171 |
+
assert not use_conv
|
| 172 |
+
assert not use_convb
|
| 173 |
+
|
| 174 |
+
if norm_layer is None:
|
| 175 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 176 |
+
|
| 177 |
+
if use_attn:
|
| 178 |
+
num_heads = 8
|
| 179 |
+
self.block = AttnBlock(
|
| 180 |
+
hidden_size=dim,
|
| 181 |
+
num_heads=num_heads,
|
| 182 |
+
mlp_ratio=4,
|
| 183 |
+
attn_class=Attention,
|
| 184 |
+
)
|
| 185 |
+
elif use_mixer:
|
| 186 |
+
self.block = MLPMixerBlock(
|
| 187 |
+
S=16,
|
| 188 |
+
dim=dim,
|
| 189 |
+
depth=1,
|
| 190 |
+
expansion_factor=2,
|
| 191 |
+
)
|
| 192 |
+
elif use_conv:
|
| 193 |
+
self.block = nn.Sequential(
|
| 194 |
+
nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
|
| 195 |
+
Permute([0, 2, 1]),
|
| 196 |
+
norm_layer(dim),
|
| 197 |
+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
|
| 198 |
+
nn.GELU(),
|
| 199 |
+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
|
| 200 |
+
Permute([0, 2, 1]),
|
| 201 |
+
)
|
| 202 |
+
elif use_convb:
|
| 203 |
+
self.block = nn.Sequential(
|
| 204 |
+
nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=True, padding_mode='zeros'),
|
| 205 |
+
Permute([0, 2, 1]),
|
| 206 |
+
norm_layer(dim),
|
| 207 |
+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
|
| 208 |
+
nn.GELU(),
|
| 209 |
+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
|
| 210 |
+
Permute([0, 2, 1]),
|
| 211 |
+
)
|
| 212 |
+
else:
|
| 213 |
+
assert(False) # choose attn, mixer, or conv please
|
| 214 |
+
|
| 215 |
+
if self.use_layer_scale:
|
| 216 |
+
self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
|
| 217 |
+
else:
|
| 218 |
+
self.layer_scale = 1.0
|
| 219 |
+
|
| 220 |
+
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
|
| 221 |
+
|
| 222 |
+
if output_dim != dim:
|
| 223 |
+
self.final = nn.Conv1d(dim, output_dim, kernel_size=1, padding=0)
|
| 224 |
+
else:
|
| 225 |
+
self.final = nn.Identity()
|
| 226 |
+
|
| 227 |
+
def forward(self, input, S=None):
|
| 228 |
+
if self.dense:
|
| 229 |
+
assert S is not None
|
| 230 |
+
BS,C,H,W = input.shape
|
| 231 |
+
B = BS//S
|
| 232 |
+
|
| 233 |
+
input = einops.rearrange(input, '(b s) c h w -> (b h w) c s', b=B, s=S, c=C, h=H, w=W)
|
| 234 |
+
|
| 235 |
+
if self.use_mixer or self.use_attn:
|
| 236 |
+
# mixer/transformer blocks want B,S,C
|
| 237 |
+
result = self.layer_scale * self.block(input.permute(0,2,1)).permute(0,2,1)
|
| 238 |
+
else:
|
| 239 |
+
result = self.layer_scale * self.block(input)
|
| 240 |
+
result = self.stochastic_depth(result)
|
| 241 |
+
result += input
|
| 242 |
+
result = self.final(result)
|
| 243 |
+
|
| 244 |
+
result = einops.rearrange(result, '(b h w) c s -> (b s) c h w', b=B, s=S, c=C, h=H, w=W)
|
| 245 |
+
else:
|
| 246 |
+
B,S,C = input.shape
|
| 247 |
+
|
| 248 |
+
if S<7:
|
| 249 |
+
return input
|
| 250 |
+
|
| 251 |
+
input = einops.rearrange(input, 'b s c -> b c s', b=B, s=S, c=C)
|
| 252 |
+
|
| 253 |
+
result = self.layer_scale * self.block(input)
|
| 254 |
+
result = self.stochastic_depth(result)
|
| 255 |
+
result += input
|
| 256 |
+
|
| 257 |
+
result = self.final(result)
|
| 258 |
+
|
| 259 |
+
result = einops.rearrange(result, 'b c s -> b s c', b=B, s=S, c=C)
|
| 260 |
+
|
| 261 |
+
return result
|
| 262 |
+
|
| 263 |
+
class CNBlock2d(nn.Module):
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
dim,
|
| 267 |
+
output_dim,
|
| 268 |
+
layer_scale: float = 1e-6,
|
| 269 |
+
stochastic_depth_prob: float = 0,
|
| 270 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
| 271 |
+
use_layer_scale=True,
|
| 272 |
+
) -> None:
|
| 273 |
+
super().__init__()
|
| 274 |
+
self.use_layer_scale = use_layer_scale
|
| 275 |
+
if norm_layer is None:
|
| 276 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 277 |
+
|
| 278 |
+
self.block = nn.Sequential(
|
| 279 |
+
nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True, padding_mode='zeros'),
|
| 280 |
+
Permute([0, 2, 3, 1]),
|
| 281 |
+
norm_layer(dim),
|
| 282 |
+
nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
|
| 283 |
+
nn.GELU(),
|
| 284 |
+
nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
|
| 285 |
+
Permute([0, 3, 1, 2]),
|
| 286 |
+
)
|
| 287 |
+
if self.use_layer_scale:
|
| 288 |
+
self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
|
| 289 |
+
else:
|
| 290 |
+
self.layer_scale = 1.0
|
| 291 |
+
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
|
| 292 |
+
|
| 293 |
+
if output_dim != dim:
|
| 294 |
+
self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
|
| 295 |
+
else:
|
| 296 |
+
self.final = nn.Identity()
|
| 297 |
+
|
| 298 |
+
def forward(self, input, S=None):
|
| 299 |
+
result = self.layer_scale * self.block(input)
|
| 300 |
+
result = self.stochastic_depth(result)
|
| 301 |
+
result += input
|
| 302 |
+
result = self.final(result)
|
| 303 |
+
return result
|
| 304 |
+
|
| 305 |
+
class CNBlockConfig:
|
| 306 |
+
# Stores information listed at Section 3 of the ConvNeXt paper
|
| 307 |
+
def __init__(
|
| 308 |
+
self,
|
| 309 |
+
input_channels: int,
|
| 310 |
+
out_channels: Optional[int],
|
| 311 |
+
num_layers: int,
|
| 312 |
+
downsample: bool,
|
| 313 |
+
) -> None:
|
| 314 |
+
self.input_channels = input_channels
|
| 315 |
+
self.out_channels = out_channels
|
| 316 |
+
self.num_layers = num_layers
|
| 317 |
+
self.downsample = downsample
|
| 318 |
+
|
| 319 |
+
def __repr__(self) -> str:
|
| 320 |
+
s = self.__class__.__name__ + "("
|
| 321 |
+
s += "input_channels={input_channels}"
|
| 322 |
+
s += ", out_channels={out_channels}"
|
| 323 |
+
s += ", num_layers={num_layers}"
|
| 324 |
+
s += ", downsample={downsample}"
|
| 325 |
+
s += ")"
|
| 326 |
+
return s.format(**self.__dict__)
|
| 327 |
+
|
| 328 |
+
class ConvNeXt(nn.Module):
|
| 329 |
+
def __init__(
|
| 330 |
+
self,
|
| 331 |
+
block_setting: List[CNBlockConfig],
|
| 332 |
+
stochastic_depth_prob: float = 0.0,
|
| 333 |
+
layer_scale: float = 1e-6,
|
| 334 |
+
num_classes: int = 1000,
|
| 335 |
+
block: Optional[Callable[..., nn.Module]] = None,
|
| 336 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
| 337 |
+
init_weights=True):
|
| 338 |
+
super().__init__()
|
| 339 |
+
|
| 340 |
+
self.init_weights = init_weights
|
| 341 |
+
|
| 342 |
+
if not block_setting:
|
| 343 |
+
raise ValueError("The block_setting should not be empty")
|
| 344 |
+
elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
|
| 345 |
+
raise TypeError("The block_setting should be List[CNBlockConfig]")
|
| 346 |
+
|
| 347 |
+
if block is None:
|
| 348 |
+
block = CNBlock2d
|
| 349 |
+
|
| 350 |
+
if norm_layer is None:
|
| 351 |
+
norm_layer = partial(LayerNorm2d, eps=1e-6)
|
| 352 |
+
|
| 353 |
+
layers: List[nn.Module] = []
|
| 354 |
+
|
| 355 |
+
# Stem
|
| 356 |
+
firstconv_output_channels = block_setting[0].input_channels
|
| 357 |
+
layers.append(
|
| 358 |
+
Conv2dNormActivation(
|
| 359 |
+
3,
|
| 360 |
+
firstconv_output_channels,
|
| 361 |
+
kernel_size=4,
|
| 362 |
+
stride=4,
|
| 363 |
+
padding=0,
|
| 364 |
+
norm_layer=norm_layer,
|
| 365 |
+
activation_layer=None,
|
| 366 |
+
bias=True,
|
| 367 |
+
)
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
|
| 371 |
+
stage_block_id = 0
|
| 372 |
+
for cnf in block_setting:
|
| 373 |
+
# Bottlenecks
|
| 374 |
+
stage: List[nn.Module] = []
|
| 375 |
+
for _ in range(cnf.num_layers):
|
| 376 |
+
# adjust stochastic depth probability based on the depth of the stage block
|
| 377 |
+
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
|
| 378 |
+
stage.append(block(cnf.input_channels, cnf.input_channels, layer_scale, sd_prob))
|
| 379 |
+
stage_block_id += 1
|
| 380 |
+
layers.append(nn.Sequential(*stage))
|
| 381 |
+
if cnf.out_channels is not None:
|
| 382 |
+
if cnf.downsample:
|
| 383 |
+
layers.append(
|
| 384 |
+
nn.Sequential(
|
| 385 |
+
norm_layer(cnf.input_channels),
|
| 386 |
+
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
|
| 387 |
+
)
|
| 388 |
+
)
|
| 389 |
+
else:
|
| 390 |
+
# we convert the 2x2 downsampling layer into a 3x3 with dilation2 and replicate padding.
|
| 391 |
+
# replicate padding compensates for the fact that this kernel never saw zero-padding.
|
| 392 |
+
layers.append(
|
| 393 |
+
nn.Sequential(
|
| 394 |
+
norm_layer(cnf.input_channels),
|
| 395 |
+
nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=3, stride=1, padding=2, dilation=2, padding_mode='zeros'),
|
| 396 |
+
)
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
self.features = nn.Sequential(*layers)
|
| 400 |
+
|
| 401 |
+
# self.final_conv = conv1x1(block_setting[-1].input_channels, output_dim)
|
| 402 |
+
|
| 403 |
+
for m in self.modules():
|
| 404 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
| 405 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 406 |
+
if m.bias is not None:
|
| 407 |
+
nn.init.zeros_(m.bias)
|
| 408 |
+
|
| 409 |
+
if self.init_weights:
|
| 410 |
+
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
|
| 411 |
+
pretrained_dict = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT).state_dict()
|
| 412 |
+
# from torchvision.models import convnext_base, ConvNeXt_Base_Weights
|
| 413 |
+
# pretrained_dict = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT).state_dict()
|
| 414 |
+
model_dict = self.state_dict()
|
| 415 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
| 416 |
+
|
| 417 |
+
for k, v in pretrained_dict.items():
|
| 418 |
+
if k == 'features.4.1.weight': # this is the layer normally in charge of 2x2 downsampling
|
| 419 |
+
# convert to 3x3 filter
|
| 420 |
+
pretrained_dict[k] = F.interpolate(v, (3, 3), mode='bicubic', align_corners=True) * (4/9.0)
|
| 421 |
+
|
| 422 |
+
model_dict.update(pretrained_dict)
|
| 423 |
+
self.load_state_dict(model_dict, strict=False)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def _forward_impl(self, x: Tensor) -> Tensor:
|
| 427 |
+
x = self.features(x)
|
| 428 |
+
# x = self.final_conv(x)
|
| 429 |
+
return x
|
| 430 |
+
|
| 431 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 432 |
+
return self._forward_impl(x)
|
| 433 |
+
|
| 434 |
+
class Mlp(nn.Module):
|
| 435 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 436 |
+
|
| 437 |
+
def __init__(
|
| 438 |
+
self,
|
| 439 |
+
in_features,
|
| 440 |
+
hidden_features=None,
|
| 441 |
+
out_features=None,
|
| 442 |
+
act_layer=nn.GELU,
|
| 443 |
+
norm_layer=None,
|
| 444 |
+
bias=True,
|
| 445 |
+
drop=0.0,
|
| 446 |
+
use_conv=False,
|
| 447 |
+
):
|
| 448 |
+
super().__init__()
|
| 449 |
+
out_features = out_features or in_features
|
| 450 |
+
hidden_features = hidden_features or in_features
|
| 451 |
+
bias = to_2tuple(bias)
|
| 452 |
+
drop_probs = to_2tuple(drop)
|
| 453 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 454 |
+
|
| 455 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
| 456 |
+
self.act = act_layer()
|
| 457 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 458 |
+
self.norm = (
|
| 459 |
+
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
| 460 |
+
)
|
| 461 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
| 462 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 463 |
+
|
| 464 |
+
def forward(self, x):
|
| 465 |
+
x = self.fc1(x)
|
| 466 |
+
x = self.act(x)
|
| 467 |
+
x = self.drop1(x)
|
| 468 |
+
x = self.fc2(x)
|
| 469 |
+
x = self.drop2(x)
|
| 470 |
+
return x
|
| 471 |
+
|
| 472 |
+
class Attention(nn.Module):
|
| 473 |
+
def __init__(
|
| 474 |
+
self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
|
| 475 |
+
):
|
| 476 |
+
super().__init__()
|
| 477 |
+
inner_dim = dim_head * num_heads
|
| 478 |
+
context_dim = default(context_dim, query_dim)
|
| 479 |
+
self.scale = dim_head**-0.5
|
| 480 |
+
self.heads = num_heads
|
| 481 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
| 482 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
| 483 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
| 484 |
+
|
| 485 |
+
def forward(self, x, context=None, attn_bias=None):
|
| 486 |
+
B, N1, C = x.shape
|
| 487 |
+
H = self.heads
|
| 488 |
+
q = self.to_q(x)
|
| 489 |
+
context = default(context, x)
|
| 490 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
| 491 |
+
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
|
| 492 |
+
x = F.scaled_dot_product_attention(q, k, v) # scale default is already dim^-0.5
|
| 493 |
+
x = einops.rearrange(x, 'b h n d -> b n (h d)')
|
| 494 |
+
return self.to_out(x)
|
| 495 |
+
|
| 496 |
+
class CrossAttnBlock(nn.Module):
|
| 497 |
+
def __init__(
|
| 498 |
+
self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
|
| 499 |
+
):
|
| 500 |
+
super().__init__()
|
| 501 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 502 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
| 503 |
+
self.cross_attn = Attention(
|
| 504 |
+
hidden_size,
|
| 505 |
+
context_dim=context_dim,
|
| 506 |
+
num_heads=num_heads,
|
| 507 |
+
qkv_bias=True,
|
| 508 |
+
**block_kwargs
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 512 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 513 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 514 |
+
self.mlp = Mlp(
|
| 515 |
+
in_features=hidden_size,
|
| 516 |
+
hidden_features=mlp_hidden_dim,
|
| 517 |
+
act_layer=approx_gelu,
|
| 518 |
+
drop=0,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
def forward(self, x, context, mask=None):
|
| 522 |
+
attn_bias = None
|
| 523 |
+
if mask is not None:
|
| 524 |
+
if mask.shape[1] == x.shape[1]:
|
| 525 |
+
mask = mask[:, None, :, None].expand(
|
| 526 |
+
-1, self.cross_attn.heads, -1, context.shape[1]
|
| 527 |
+
)
|
| 528 |
+
else:
|
| 529 |
+
mask = mask[:, None, None].expand(
|
| 530 |
+
-1, self.cross_attn.heads, x.shape[1], -1
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
| 534 |
+
attn_bias = (~mask) * max_neg_value
|
| 535 |
+
x = x + self.cross_attn(
|
| 536 |
+
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
|
| 537 |
+
)
|
| 538 |
+
x = x + self.mlp(self.norm2(x))
|
| 539 |
+
return x
|
| 540 |
+
|
| 541 |
+
class AttnBlock(nn.Module):
|
| 542 |
+
def __init__(
|
| 543 |
+
self,
|
| 544 |
+
hidden_size,
|
| 545 |
+
num_heads,
|
| 546 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 547 |
+
mlp_ratio=4.0,
|
| 548 |
+
**block_kwargs
|
| 549 |
+
):
|
| 550 |
+
super().__init__()
|
| 551 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 552 |
+
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, dim_head=hidden_size//num_heads)
|
| 553 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 554 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 555 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 556 |
+
self.mlp = Mlp(
|
| 557 |
+
in_features=hidden_size,
|
| 558 |
+
hidden_features=mlp_hidden_dim,
|
| 559 |
+
act_layer=approx_gelu,
|
| 560 |
+
drop=0,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
def forward(self, x, mask=None):
|
| 564 |
+
attn_bias = mask
|
| 565 |
+
if mask is not None:
|
| 566 |
+
mask = (
|
| 567 |
+
(mask[:, None] * mask[:, :, None])
|
| 568 |
+
.unsqueeze(1)
|
| 569 |
+
.expand(-1, self.attn.num_heads, -1, -1)
|
| 570 |
+
)
|
| 571 |
+
max_neg_value = -torch.finfo(x.dtype).max
|
| 572 |
+
attn_bias = (~mask) * max_neg_value
|
| 573 |
+
|
| 574 |
+
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 575 |
+
x = x + self.mlp(self.norm2(x))
|
| 576 |
+
return x
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class ResidualBlock(nn.Module):
|
| 580 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
| 581 |
+
super(ResidualBlock, self).__init__()
|
| 582 |
+
|
| 583 |
+
self.conv1 = nn.Conv2d(
|
| 584 |
+
in_planes,
|
| 585 |
+
planes,
|
| 586 |
+
kernel_size=3,
|
| 587 |
+
padding=1,
|
| 588 |
+
stride=stride,
|
| 589 |
+
padding_mode="zeros",
|
| 590 |
+
)
|
| 591 |
+
self.conv2 = nn.Conv2d(
|
| 592 |
+
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
| 593 |
+
)
|
| 594 |
+
self.relu = nn.ReLU(inplace=True)
|
| 595 |
+
|
| 596 |
+
num_groups = planes // 8
|
| 597 |
+
|
| 598 |
+
if norm_fn == "group":
|
| 599 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 600 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 601 |
+
if not stride == 1:
|
| 602 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 603 |
+
|
| 604 |
+
elif norm_fn == "batch":
|
| 605 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
| 606 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
| 607 |
+
if not stride == 1:
|
| 608 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
| 609 |
+
|
| 610 |
+
elif norm_fn == "instance":
|
| 611 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
| 612 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
| 613 |
+
if not stride == 1:
|
| 614 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
| 615 |
+
|
| 616 |
+
elif norm_fn == "none":
|
| 617 |
+
self.norm1 = nn.Sequential()
|
| 618 |
+
self.norm2 = nn.Sequential()
|
| 619 |
+
if not stride == 1:
|
| 620 |
+
self.norm3 = nn.Sequential()
|
| 621 |
+
|
| 622 |
+
if stride == 1:
|
| 623 |
+
self.downsample = None
|
| 624 |
+
|
| 625 |
+
else:
|
| 626 |
+
self.downsample = nn.Sequential(
|
| 627 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
def forward(self, x):
|
| 631 |
+
y = x
|
| 632 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 633 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 634 |
+
|
| 635 |
+
if self.downsample is not None:
|
| 636 |
+
x = self.downsample(x)
|
| 637 |
+
|
| 638 |
+
return self.relu(x + y)
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
class BasicEncoder(nn.Module):
|
| 642 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
| 643 |
+
super(BasicEncoder, self).__init__()
|
| 644 |
+
self.stride = stride
|
| 645 |
+
self.norm_fn = "instance"
|
| 646 |
+
self.in_planes = output_dim // 2
|
| 647 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
| 648 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
| 649 |
+
|
| 650 |
+
self.conv1 = nn.Conv2d(
|
| 651 |
+
input_dim,
|
| 652 |
+
self.in_planes,
|
| 653 |
+
kernel_size=7,
|
| 654 |
+
stride=2,
|
| 655 |
+
padding=3,
|
| 656 |
+
padding_mode="zeros",
|
| 657 |
+
)
|
| 658 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 659 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
| 660 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
| 661 |
+
self.layer3 = self._make_layer(output_dim, stride=2)
|
| 662 |
+
self.layer4 = self._make_layer(output_dim, stride=2)
|
| 663 |
+
|
| 664 |
+
self.conv2 = nn.Conv2d(
|
| 665 |
+
output_dim * 3 + output_dim // 4,
|
| 666 |
+
output_dim * 2,
|
| 667 |
+
kernel_size=3,
|
| 668 |
+
padding=1,
|
| 669 |
+
padding_mode="zeros",
|
| 670 |
+
)
|
| 671 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 672 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
| 673 |
+
for m in self.modules():
|
| 674 |
+
if isinstance(m, nn.Conv2d):
|
| 675 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 676 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
| 677 |
+
if m.weight is not None:
|
| 678 |
+
nn.init.constant_(m.weight, 1)
|
| 679 |
+
if m.bias is not None:
|
| 680 |
+
nn.init.constant_(m.bias, 0)
|
| 681 |
+
|
| 682 |
+
def _make_layer(self, dim, stride=1):
|
| 683 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 684 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
| 685 |
+
layers = (layer1, layer2)
|
| 686 |
+
|
| 687 |
+
self.in_planes = dim
|
| 688 |
+
return nn.Sequential(*layers)
|
| 689 |
+
|
| 690 |
+
def forward(self, x):
|
| 691 |
+
_, _, H, W = x.shape
|
| 692 |
+
|
| 693 |
+
x = self.conv1(x)
|
| 694 |
+
x = self.norm1(x)
|
| 695 |
+
x = self.relu1(x)
|
| 696 |
+
|
| 697 |
+
a = self.layer1(x)
|
| 698 |
+
b = self.layer2(a)
|
| 699 |
+
c = self.layer3(b)
|
| 700 |
+
d = self.layer4(c)
|
| 701 |
+
|
| 702 |
+
def _bilinear_intepolate(x):
|
| 703 |
+
return F.interpolate(
|
| 704 |
+
x,
|
| 705 |
+
(H // self.stride, W // self.stride),
|
| 706 |
+
mode="bilinear",
|
| 707 |
+
align_corners=True,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
a = _bilinear_intepolate(a)
|
| 711 |
+
b = _bilinear_intepolate(b)
|
| 712 |
+
c = _bilinear_intepolate(c)
|
| 713 |
+
d = _bilinear_intepolate(d)
|
| 714 |
+
|
| 715 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
| 716 |
+
x = self.norm2(x)
|
| 717 |
+
x = self.relu2(x)
|
| 718 |
+
x = self.conv3(x)
|
| 719 |
+
return x
|
| 720 |
+
|
| 721 |
+
class EfficientUpdateFormer(nn.Module):
|
| 722 |
+
"""
|
| 723 |
+
Transformer model that updates track estimates.
|
| 724 |
+
"""
|
| 725 |
+
|
| 726 |
+
def __init__(
|
| 727 |
+
self,
|
| 728 |
+
space_depth=6,
|
| 729 |
+
time_depth=6,
|
| 730 |
+
input_dim=320,
|
| 731 |
+
hidden_size=384,
|
| 732 |
+
num_heads=8,
|
| 733 |
+
output_dim=130,
|
| 734 |
+
mlp_ratio=4.0,
|
| 735 |
+
num_virtual_tracks=64,
|
| 736 |
+
add_space_attn=True,
|
| 737 |
+
linear_layer_for_vis_conf=False,
|
| 738 |
+
use_time_conv=False,
|
| 739 |
+
use_time_mixer=False,
|
| 740 |
+
):
|
| 741 |
+
super().__init__()
|
| 742 |
+
self.out_channels = 2
|
| 743 |
+
self.num_heads = num_heads
|
| 744 |
+
self.hidden_size = hidden_size
|
| 745 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 746 |
+
if linear_layer_for_vis_conf:
|
| 747 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
|
| 748 |
+
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
|
| 749 |
+
else:
|
| 750 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
| 751 |
+
self.num_virtual_tracks = num_virtual_tracks
|
| 752 |
+
self.virual_tracks = nn.Parameter(
|
| 753 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
| 754 |
+
)
|
| 755 |
+
self.add_space_attn = add_space_attn
|
| 756 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
| 757 |
+
|
| 758 |
+
if use_time_conv:
|
| 759 |
+
self.time_blocks = nn.ModuleList(
|
| 760 |
+
[
|
| 761 |
+
CNBlock1d(hidden_size, hidden_size, dense=False)
|
| 762 |
+
for _ in range(time_depth)
|
| 763 |
+
]
|
| 764 |
+
)
|
| 765 |
+
elif use_time_mixer:
|
| 766 |
+
self.time_blocks = nn.ModuleList(
|
| 767 |
+
[
|
| 768 |
+
MLPMixerBlock(
|
| 769 |
+
S=16,
|
| 770 |
+
dim=hidden_size,
|
| 771 |
+
depth=1,
|
| 772 |
+
)
|
| 773 |
+
for _ in range(time_depth)
|
| 774 |
+
]
|
| 775 |
+
)
|
| 776 |
+
else:
|
| 777 |
+
self.time_blocks = nn.ModuleList(
|
| 778 |
+
[
|
| 779 |
+
AttnBlock(
|
| 780 |
+
hidden_size,
|
| 781 |
+
num_heads,
|
| 782 |
+
mlp_ratio=mlp_ratio,
|
| 783 |
+
attn_class=Attention,
|
| 784 |
+
)
|
| 785 |
+
for _ in range(time_depth)
|
| 786 |
+
]
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
if add_space_attn:
|
| 790 |
+
self.space_virtual_blocks = nn.ModuleList(
|
| 791 |
+
[
|
| 792 |
+
AttnBlock(
|
| 793 |
+
hidden_size,
|
| 794 |
+
num_heads,
|
| 795 |
+
mlp_ratio=mlp_ratio,
|
| 796 |
+
attn_class=Attention,
|
| 797 |
+
)
|
| 798 |
+
for _ in range(space_depth)
|
| 799 |
+
]
|
| 800 |
+
)
|
| 801 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
| 802 |
+
[
|
| 803 |
+
CrossAttnBlock(
|
| 804 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 805 |
+
)
|
| 806 |
+
for _ in range(space_depth)
|
| 807 |
+
]
|
| 808 |
+
)
|
| 809 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
| 810 |
+
[
|
| 811 |
+
CrossAttnBlock(
|
| 812 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 813 |
+
)
|
| 814 |
+
for _ in range(space_depth)
|
| 815 |
+
]
|
| 816 |
+
)
|
| 817 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
| 818 |
+
self.initialize_weights()
|
| 819 |
+
|
| 820 |
+
def initialize_weights(self):
|
| 821 |
+
def _basic_init(module):
|
| 822 |
+
if isinstance(module, nn.Linear):
|
| 823 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 824 |
+
if module.bias is not None:
|
| 825 |
+
nn.init.constant_(module.bias, 0)
|
| 826 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
| 827 |
+
if self.linear_layer_for_vis_conf:
|
| 828 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
|
| 829 |
+
|
| 830 |
+
def _trunc_init(module):
|
| 831 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 832 |
+
if isinstance(module, nn.Linear):
|
| 833 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
| 834 |
+
if module.bias is not None:
|
| 835 |
+
nn.init.zeros_(module.bias)
|
| 836 |
+
|
| 837 |
+
self.apply(_basic_init)
|
| 838 |
+
|
| 839 |
+
def forward(self, input_tensor, mask=None, add_space_attn=True):
|
| 840 |
+
tokens = self.input_transform(input_tensor)
|
| 841 |
+
|
| 842 |
+
B, _, T, _ = tokens.shape
|
| 843 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
| 844 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
| 845 |
+
|
| 846 |
+
_, N, _, _ = tokens.shape
|
| 847 |
+
j = 0
|
| 848 |
+
layers = []
|
| 849 |
+
for i in range(len(self.time_blocks)):
|
| 850 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 851 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
| 852 |
+
|
| 853 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
| 854 |
+
if (
|
| 855 |
+
add_space_attn
|
| 856 |
+
and hasattr(self, "space_virtual_blocks")
|
| 857 |
+
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
|
| 858 |
+
):
|
| 859 |
+
space_tokens = (
|
| 860 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
| 861 |
+
) # B N T C -> (B T) N C
|
| 862 |
+
|
| 863 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
| 864 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
| 865 |
+
|
| 866 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
| 867 |
+
virtual_tokens, point_tokens, mask=mask
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
| 871 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
| 872 |
+
point_tokens, virtual_tokens, mask=mask
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
| 876 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
| 877 |
+
0, 2, 1, 3
|
| 878 |
+
) # (B T) N C -> B N T C
|
| 879 |
+
j += 1
|
| 880 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
| 881 |
+
|
| 882 |
+
flow = self.flow_head(tokens)
|
| 883 |
+
if self.linear_layer_for_vis_conf:
|
| 884 |
+
vis_conf = self.vis_conf_head(tokens)
|
| 885 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
| 886 |
+
|
| 887 |
+
return flow
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
class MMPreNormResidual(nn.Module):
|
| 891 |
+
def __init__(self, dim, fn):
|
| 892 |
+
super().__init__()
|
| 893 |
+
self.fn = fn
|
| 894 |
+
self.norm = nn.LayerNorm(dim)
|
| 895 |
+
|
| 896 |
+
def forward(self, x):
|
| 897 |
+
return self.fn(self.norm(x)) + x
|
| 898 |
+
|
| 899 |
+
def MMFeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
|
| 900 |
+
return nn.Sequential(
|
| 901 |
+
dense(dim, dim * expansion_factor),
|
| 902 |
+
nn.GELU(),
|
| 903 |
+
nn.Dropout(dropout),
|
| 904 |
+
dense(dim * expansion_factor, dim),
|
| 905 |
+
nn.Dropout(dropout)
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
def MLPMixer(S, input_dim, dim, output_dim, depth=6, expansion_factor=4, dropout=0., do_reduce=False):
|
| 909 |
+
# input is coming in as B,S,C, as standard for mlp and transformer
|
| 910 |
+
# chan_first treats S as the channel dim, and transforms it to a new S
|
| 911 |
+
# chan_last treats C as the channel dim, and transforms it to a new C
|
| 912 |
+
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
|
| 913 |
+
if do_reduce:
|
| 914 |
+
return nn.Sequential(
|
| 915 |
+
nn.Linear(input_dim, dim),
|
| 916 |
+
*[nn.Sequential(
|
| 917 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
| 918 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
| 919 |
+
) for _ in range(depth)],
|
| 920 |
+
nn.LayerNorm(dim),
|
| 921 |
+
Reduce('b n c -> b c', 'mean'),
|
| 922 |
+
nn.Linear(dim, output_dim)
|
| 923 |
+
)
|
| 924 |
+
else:
|
| 925 |
+
return nn.Sequential(
|
| 926 |
+
nn.Linear(input_dim, dim),
|
| 927 |
+
*[nn.Sequential(
|
| 928 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
| 929 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
| 930 |
+
) for _ in range(depth)],
|
| 931 |
+
)
|
| 932 |
+
|
| 933 |
+
def MLPMixerBlock(S, dim, depth=1, expansion_factor=4, dropout=0., do_reduce=False):
|
| 934 |
+
# input is coming in as B,S,C, as standard for mlp and transformer
|
| 935 |
+
# chan_first treats S as the channel dim, and transforms it to a new S
|
| 936 |
+
# chan_last treats C as the channel dim, and transforms it to a new C
|
| 937 |
+
chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
|
| 938 |
+
return nn.Sequential(
|
| 939 |
+
*[nn.Sequential(
|
| 940 |
+
MMPreNormResidual(dim, MMFeedForward(S, expansion_factor, dropout, chan_first)),
|
| 941 |
+
MMPreNormResidual(dim, MMFeedForward(dim, expansion_factor, dropout, chan_last))
|
| 942 |
+
) for _ in range(depth)],
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
class MlpUpdateFormer(nn.Module):
|
| 947 |
+
"""
|
| 948 |
+
Transformer model that updates track estimates.
|
| 949 |
+
"""
|
| 950 |
+
|
| 951 |
+
def __init__(
|
| 952 |
+
self,
|
| 953 |
+
space_depth=6,
|
| 954 |
+
time_depth=6,
|
| 955 |
+
input_dim=320,
|
| 956 |
+
hidden_size=384,
|
| 957 |
+
num_heads=8,
|
| 958 |
+
output_dim=130,
|
| 959 |
+
mlp_ratio=4.0,
|
| 960 |
+
num_virtual_tracks=64,
|
| 961 |
+
add_space_attn=True,
|
| 962 |
+
linear_layer_for_vis_conf=False,
|
| 963 |
+
):
|
| 964 |
+
super().__init__()
|
| 965 |
+
self.out_channels = 2
|
| 966 |
+
self.num_heads = num_heads
|
| 967 |
+
self.hidden_size = hidden_size
|
| 968 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 969 |
+
if linear_layer_for_vis_conf:
|
| 970 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
|
| 971 |
+
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
|
| 972 |
+
else:
|
| 973 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
| 974 |
+
self.num_virtual_tracks = num_virtual_tracks
|
| 975 |
+
self.virual_tracks = nn.Parameter(
|
| 976 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
| 977 |
+
)
|
| 978 |
+
self.add_space_attn = add_space_attn
|
| 979 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
| 980 |
+
self.time_blocks = nn.ModuleList(
|
| 981 |
+
[
|
| 982 |
+
MLPMixer(
|
| 983 |
+
S=16,
|
| 984 |
+
input_dim=hidden_size,
|
| 985 |
+
dim=hidden_size,
|
| 986 |
+
output_dim=hidden_size,
|
| 987 |
+
depth=1,
|
| 988 |
+
)
|
| 989 |
+
for _ in range(time_depth)
|
| 990 |
+
]
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
if add_space_attn:
|
| 994 |
+
self.space_virtual_blocks = nn.ModuleList(
|
| 995 |
+
[
|
| 996 |
+
AttnBlock(
|
| 997 |
+
hidden_size,
|
| 998 |
+
num_heads,
|
| 999 |
+
mlp_ratio=mlp_ratio,
|
| 1000 |
+
attn_class=Attention,
|
| 1001 |
+
)
|
| 1002 |
+
for _ in range(space_depth)
|
| 1003 |
+
]
|
| 1004 |
+
)
|
| 1005 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
| 1006 |
+
[
|
| 1007 |
+
CrossAttnBlock(
|
| 1008 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 1009 |
+
)
|
| 1010 |
+
for _ in range(space_depth)
|
| 1011 |
+
]
|
| 1012 |
+
)
|
| 1013 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
| 1014 |
+
[
|
| 1015 |
+
CrossAttnBlock(
|
| 1016 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 1017 |
+
)
|
| 1018 |
+
for _ in range(space_depth)
|
| 1019 |
+
]
|
| 1020 |
+
)
|
| 1021 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
| 1022 |
+
self.initialize_weights()
|
| 1023 |
+
|
| 1024 |
+
def initialize_weights(self):
|
| 1025 |
+
def _basic_init(module):
|
| 1026 |
+
if isinstance(module, nn.Linear):
|
| 1027 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 1028 |
+
if module.bias is not None:
|
| 1029 |
+
nn.init.constant_(module.bias, 0)
|
| 1030 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
| 1031 |
+
if self.linear_layer_for_vis_conf:
|
| 1032 |
+
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
|
| 1033 |
+
|
| 1034 |
+
def _trunc_init(module):
|
| 1035 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 1036 |
+
if isinstance(module, nn.Linear):
|
| 1037 |
+
torch.nn.init.trunc_normal_(module.weight, std=0.02)
|
| 1038 |
+
if module.bias is not None:
|
| 1039 |
+
nn.init.zeros_(module.bias)
|
| 1040 |
+
|
| 1041 |
+
self.apply(_basic_init)
|
| 1042 |
+
|
| 1043 |
+
def forward(self, input_tensor, mask=None, add_space_attn=True):
|
| 1044 |
+
tokens = self.input_transform(input_tensor)
|
| 1045 |
+
|
| 1046 |
+
B, _, T, _ = tokens.shape
|
| 1047 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
| 1048 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
| 1049 |
+
|
| 1050 |
+
_, N, _, _ = tokens.shape
|
| 1051 |
+
j = 0
|
| 1052 |
+
layers = []
|
| 1053 |
+
for i in range(len(self.time_blocks)):
|
| 1054 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 1055 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
| 1056 |
+
|
| 1057 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
| 1058 |
+
if (
|
| 1059 |
+
add_space_attn
|
| 1060 |
+
and hasattr(self, "space_virtual_blocks")
|
| 1061 |
+
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
|
| 1062 |
+
):
|
| 1063 |
+
space_tokens = (
|
| 1064 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
| 1065 |
+
) # B N T C -> (B T) N C
|
| 1066 |
+
|
| 1067 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
| 1068 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
| 1069 |
+
|
| 1070 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
| 1071 |
+
virtual_tokens, point_tokens, mask=mask
|
| 1072 |
+
)
|
| 1073 |
+
|
| 1074 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
| 1075 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
| 1076 |
+
point_tokens, virtual_tokens, mask=mask
|
| 1077 |
+
)
|
| 1078 |
+
|
| 1079 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
| 1080 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
| 1081 |
+
0, 2, 1, 3
|
| 1082 |
+
) # (B T) N C -> B N T C
|
| 1083 |
+
j += 1
|
| 1084 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
| 1085 |
+
|
| 1086 |
+
flow = self.flow_head(tokens)
|
| 1087 |
+
if self.linear_layer_for_vis_conf:
|
| 1088 |
+
vis_conf = self.vis_conf_head(tokens)
|
| 1089 |
+
flow = torch.cat([flow, vis_conf], dim=-1)
|
| 1090 |
+
|
| 1091 |
+
return flow
|
| 1092 |
+
|
| 1093 |
+
class BasicMotionEncoder(nn.Module):
|
| 1094 |
+
def __init__(self, corr_channel, dim=128, pdim=2):
|
| 1095 |
+
super(BasicMotionEncoder, self).__init__()
|
| 1096 |
+
self.pdim = pdim
|
| 1097 |
+
self.convc1 = nn.Conv2d(corr_channel, dim*4, 1, padding=0)
|
| 1098 |
+
self.convc2 = nn.Conv2d(dim*4, dim+dim//2, 3, padding=1)
|
| 1099 |
+
if pdim==2 or pdim==4:
|
| 1100 |
+
self.convf1 = nn.Conv2d(pdim, dim*2, 5, padding=2)
|
| 1101 |
+
self.convf2 = nn.Conv2d(dim*2, dim//2, 3, padding=1)
|
| 1102 |
+
self.conv = nn.Conv2d(dim*2, dim-pdim, 3, padding=1)
|
| 1103 |
+
else:
|
| 1104 |
+
self.conv = nn.Conv2d(dim+dim//2+pdim, dim, 3, padding=1)
|
| 1105 |
+
|
| 1106 |
+
def forward(self, flow, corr):
|
| 1107 |
+
cor = F.relu(self.convc1(corr))
|
| 1108 |
+
cor = F.relu(self.convc2(cor))
|
| 1109 |
+
if self.pdim==2 or self.pdim==4:
|
| 1110 |
+
flo = F.relu(self.convf1(flow))
|
| 1111 |
+
flo = F.relu(self.convf2(flo))
|
| 1112 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
| 1113 |
+
out = F.relu(self.conv(cor_flo))
|
| 1114 |
+
return torch.cat([out, flow], dim=1)
|
| 1115 |
+
else:
|
| 1116 |
+
# the flow is already encoded to something nice
|
| 1117 |
+
cor_flo = torch.cat([cor, flow], dim=1)
|
| 1118 |
+
return F.relu(self.conv(cor_flo))
|
| 1119 |
+
# return torch.cat([out, flow], dim=1)
|
| 1120 |
+
|
| 1121 |
+
def conv133_encoder(input_dim, dim, expansion_factor=4):
|
| 1122 |
+
return nn.Sequential(
|
| 1123 |
+
nn.Conv2d(input_dim, dim*expansion_factor, kernel_size=1),
|
| 1124 |
+
nn.GELU(),
|
| 1125 |
+
nn.Conv2d(dim*expansion_factor, dim*expansion_factor, kernel_size=3, padding=1),
|
| 1126 |
+
nn.GELU(),
|
| 1127 |
+
nn.Conv2d(dim*expansion_factor, dim, kernel_size=3, padding=1),
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
class BasicUpdateBlock(nn.Module):
|
| 1131 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
|
| 1132 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
| 1133 |
+
super(BasicUpdateBlock, self).__init__()
|
| 1134 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
|
| 1135 |
+
self.compressor = conv1x1(2*cdim+hdim, hdim)
|
| 1136 |
+
|
| 1137 |
+
self.refine = []
|
| 1138 |
+
for i in range(num_blocks):
|
| 1139 |
+
self.refine.append(CNBlock1d(hdim, hdim))
|
| 1140 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
| 1141 |
+
self.refine = nn.ModuleList(self.refine)
|
| 1142 |
+
|
| 1143 |
+
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
|
| 1144 |
+
BS,C,H,W = flowfeat.shape
|
| 1145 |
+
B = BS//S
|
| 1146 |
+
|
| 1147 |
+
# with torch.no_grad():
|
| 1148 |
+
motion_features = self.encoder(flow, corr)
|
| 1149 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
|
| 1150 |
+
|
| 1151 |
+
for blk in self.refine:
|
| 1152 |
+
flowfeat = blk(flowfeat, S)
|
| 1153 |
+
return flowfeat
|
| 1154 |
+
|
| 1155 |
+
class FullUpdateBlock(nn.Module):
|
| 1156 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=2, use_attn=False):
|
| 1157 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
| 1158 |
+
super(FullUpdateBlock, self).__init__()
|
| 1159 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim, pdim=pdim)
|
| 1160 |
+
|
| 1161 |
+
# note we have hdim==cdim
|
| 1162 |
+
# compressor chans:
|
| 1163 |
+
# dim for flowfeat
|
| 1164 |
+
# dim for ctxfeat
|
| 1165 |
+
# dim for motion_features
|
| 1166 |
+
# pdim for flow (if p 2, like if we give sincos(relflow))
|
| 1167 |
+
# 2 for visconf
|
| 1168 |
+
|
| 1169 |
+
if pdim==2:
|
| 1170 |
+
# hdim==cdim
|
| 1171 |
+
# dim for flowfeat
|
| 1172 |
+
# dim for ctxfeat
|
| 1173 |
+
# dim for motion_features
|
| 1174 |
+
# 2 for visconf
|
| 1175 |
+
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
|
| 1176 |
+
else:
|
| 1177 |
+
# we concatenate the flow info again, to not lose it (e.g., from the relu)
|
| 1178 |
+
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
|
| 1179 |
+
|
| 1180 |
+
self.refine = []
|
| 1181 |
+
for i in range(num_blocks):
|
| 1182 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
|
| 1183 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
| 1184 |
+
self.refine = nn.ModuleList(self.refine)
|
| 1185 |
+
|
| 1186 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
| 1187 |
+
BS,C,H,W = flowfeat.shape
|
| 1188 |
+
B = BS//S
|
| 1189 |
+
motion_features = self.encoder(flow, corr)
|
| 1190 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
|
| 1191 |
+
for blk in self.refine:
|
| 1192 |
+
flowfeat = blk(flowfeat, S)
|
| 1193 |
+
return flowfeat
|
| 1194 |
+
|
| 1195 |
+
class MixerUpdateBlock(nn.Module):
|
| 1196 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128):
|
| 1197 |
+
# flowfeat is hdim; ctxfeat is dim. typically hdim==cdim.
|
| 1198 |
+
super(MixerUpdateBlock, self).__init__()
|
| 1199 |
+
self.encoder = BasicMotionEncoder(corr_channel, dim=cdim)
|
| 1200 |
+
self.compressor = conv1x1(2*cdim+hdim, hdim)
|
| 1201 |
+
|
| 1202 |
+
self.refine = []
|
| 1203 |
+
for i in range(num_blocks):
|
| 1204 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_mixer=True))
|
| 1205 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
| 1206 |
+
self.refine = nn.ModuleList(self.refine)
|
| 1207 |
+
|
| 1208 |
+
def forward(self, flowfeat, ctxfeat, corr, flow, S, upsample=True):
|
| 1209 |
+
BS,C,H,W = flowfeat.shape
|
| 1210 |
+
B = BS//S
|
| 1211 |
+
|
| 1212 |
+
# with torch.no_grad():
|
| 1213 |
+
motion_features = self.encoder(flow, corr)
|
| 1214 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features], dim=1))
|
| 1215 |
+
|
| 1216 |
+
for ii, blk in enumerate(self.refine):
|
| 1217 |
+
flowfeat = blk(flowfeat, S)
|
| 1218 |
+
return flowfeat
|
| 1219 |
+
|
| 1220 |
+
class FacUpdateBlock(nn.Module):
|
| 1221 |
+
def __init__(self, corr_channel, num_blocks, hdim=128, cdim=128, pdim=84, use_attn=False):
|
| 1222 |
+
super(FacUpdateBlock, self).__init__()
|
| 1223 |
+
self.corr_encoder = conv133_encoder(corr_channel, cdim)
|
| 1224 |
+
# note we have hdim==cdim
|
| 1225 |
+
# compressor chans:
|
| 1226 |
+
# dim for flowfeat
|
| 1227 |
+
# dim for ctxfeat
|
| 1228 |
+
# dim for corr
|
| 1229 |
+
# pdim for flow
|
| 1230 |
+
# 2 for visconf
|
| 1231 |
+
self.compressor = conv1x1(2*cdim+hdim+2+pdim, hdim)
|
| 1232 |
+
self.refine = []
|
| 1233 |
+
for i in range(num_blocks):
|
| 1234 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn))
|
| 1235 |
+
self.refine.append(CNBlock2d(hdim, hdim))
|
| 1236 |
+
self.refine = nn.ModuleList(self.refine)
|
| 1237 |
+
|
| 1238 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
| 1239 |
+
BS,C,H,W = flowfeat.shape
|
| 1240 |
+
B = BS//S
|
| 1241 |
+
corr = self.corr_encoder(corr)
|
| 1242 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corr, visconf, flow], dim=1))
|
| 1243 |
+
for blk in self.refine:
|
| 1244 |
+
flowfeat = blk(flowfeat, S)
|
| 1245 |
+
return flowfeat
|
| 1246 |
+
|
| 1247 |
+
class CleanUpdateBlock(nn.Module):
|
| 1248 |
+
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=256, pdim=84, use_attn=False, use_layer_scale=True):
|
| 1249 |
+
super(CleanUpdateBlock, self).__init__()
|
| 1250 |
+
self.corr_encoder = conv133_encoder(corr_channel, cdim)
|
| 1251 |
+
# compressor chans:
|
| 1252 |
+
# cdim for flowfeat
|
| 1253 |
+
# cdim for ctxfeat
|
| 1254 |
+
# cdim for corrfeat
|
| 1255 |
+
# pdim for flow
|
| 1256 |
+
# 2 for visconf
|
| 1257 |
+
self.compressor = conv1x1(3*cdim+pdim+2, hdim)
|
| 1258 |
+
self.refine = []
|
| 1259 |
+
for i in range(num_blocks):
|
| 1260 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_layer_scale=use_layer_scale))
|
| 1261 |
+
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
|
| 1262 |
+
self.refine = nn.ModuleList(self.refine)
|
| 1263 |
+
self.final_conv = conv1x1(hdim, cdim)
|
| 1264 |
+
|
| 1265 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
| 1266 |
+
BS,C,H,W = flowfeat.shape
|
| 1267 |
+
B = BS//S
|
| 1268 |
+
corrfeat = self.corr_encoder(corr)
|
| 1269 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, corrfeat, flow, visconf], dim=1))
|
| 1270 |
+
for blk in self.refine:
|
| 1271 |
+
flowfeat = blk(flowfeat, S)
|
| 1272 |
+
flowfeat = self.final_conv(flowfeat)
|
| 1273 |
+
return flowfeat
|
| 1274 |
+
|
| 1275 |
+
class RelUpdateBlock(nn.Module):
|
| 1276 |
+
def __init__(self, corr_channel, num_blocks, cdim=128, hdim=128, pdim=4, use_attn=True, use_mixer=False, use_conv=False, use_convb=False, use_layer_scale=True, no_time=False, no_space=False, no_ctx=False):
|
| 1277 |
+
super(RelUpdateBlock, self).__init__()
|
| 1278 |
+
self.motion_encoder = BasicMotionEncoder(corr_channel, dim=hdim, pdim=pdim) # B,hdim,H,W
|
| 1279 |
+
self.no_ctx = no_ctx
|
| 1280 |
+
if no_ctx:
|
| 1281 |
+
self.compressor = conv1x1(cdim+hdim+2, hdim)
|
| 1282 |
+
else:
|
| 1283 |
+
self.compressor = conv1x1(2*cdim+hdim+2, hdim)
|
| 1284 |
+
self.refine = []
|
| 1285 |
+
for i in range(num_blocks):
|
| 1286 |
+
if not no_time:
|
| 1287 |
+
self.refine.append(CNBlock1d(hdim, hdim, use_attn=use_attn, use_mixer=use_mixer, use_conv=use_conv, use_convb=use_convb, use_layer_scale=use_layer_scale))
|
| 1288 |
+
if not no_space:
|
| 1289 |
+
self.refine.append(CNBlock2d(hdim, hdim, use_layer_scale=use_layer_scale))
|
| 1290 |
+
self.refine = nn.ModuleList(self.refine)
|
| 1291 |
+
self.final_conv = conv1x1(hdim, cdim)
|
| 1292 |
+
|
| 1293 |
+
def forward(self, flowfeat, ctxfeat, visconf, corr, flow, S, upsample=True):
|
| 1294 |
+
BS,C,H,W = flowfeat.shape
|
| 1295 |
+
B = BS//S
|
| 1296 |
+
motion_features = self.motion_encoder(flow, corr)
|
| 1297 |
+
if self.no_ctx:
|
| 1298 |
+
flowfeat = self.compressor(torch.cat([flowfeat, motion_features, visconf], dim=1))
|
| 1299 |
+
else:
|
| 1300 |
+
flowfeat = self.compressor(torch.cat([flowfeat, ctxfeat, motion_features, visconf], dim=1))
|
| 1301 |
+
for blk in self.refine:
|
| 1302 |
+
flowfeat = blk(flowfeat, S)
|
| 1303 |
+
flowfeat = self.final_conv(flowfeat)
|
| 1304 |
+
return flowfeat
|
utils/basic.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
EPS = 1e-6
|
| 5 |
+
|
| 6 |
+
def sub2ind(height, width, y, x):
|
| 7 |
+
return y*width + x
|
| 8 |
+
|
| 9 |
+
def ind2sub(height, width, ind):
|
| 10 |
+
y = ind // width
|
| 11 |
+
x = ind % width
|
| 12 |
+
return y, x
|
| 13 |
+
|
| 14 |
+
def get_lr_str(lr):
|
| 15 |
+
lrn = "%.1e" % lr # e.g., 5.0e-04
|
| 16 |
+
lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
|
| 17 |
+
return lrn
|
| 18 |
+
|
| 19 |
+
def strnum(x):
|
| 20 |
+
s = '%g' % x
|
| 21 |
+
if '.' in s:
|
| 22 |
+
if x < 1.0:
|
| 23 |
+
s = s[s.index('.'):]
|
| 24 |
+
s = s[:min(len(s),4)]
|
| 25 |
+
return s
|
| 26 |
+
|
| 27 |
+
def assert_same_shape(t1, t2):
|
| 28 |
+
for (x, y) in zip(list(t1.shape), list(t2.shape)):
|
| 29 |
+
assert(x==y)
|
| 30 |
+
|
| 31 |
+
def mkdir(path):
|
| 32 |
+
if not os.path.exists(path):
|
| 33 |
+
os.makedirs(path)
|
| 34 |
+
|
| 35 |
+
def print_stats(name, tensor):
|
| 36 |
+
shape = tensor.shape
|
| 37 |
+
tensor = tensor.detach().cpu().numpy()
|
| 38 |
+
print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
|
| 39 |
+
|
| 40 |
+
def normalize_single(d):
|
| 41 |
+
# d is a whatever shape torch tensor
|
| 42 |
+
dmin = torch.min(d)
|
| 43 |
+
dmax = torch.max(d)
|
| 44 |
+
d = (d-dmin)/(EPS+(dmax-dmin))
|
| 45 |
+
return d
|
| 46 |
+
|
| 47 |
+
def normalize(d):
|
| 48 |
+
# d is B x whatever. normalize within each element of the batch
|
| 49 |
+
out = torch.zeros(d.size(), dtype=d.dtype, device=d.device)
|
| 50 |
+
B = list(d.size())[0]
|
| 51 |
+
for b in list(range(B)):
|
| 52 |
+
out[b] = normalize_single(d[b])
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
|
| 56 |
+
# returns a meshgrid sized B x Y x X
|
| 57 |
+
|
| 58 |
+
grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
|
| 59 |
+
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
| 60 |
+
grid_y = grid_y.repeat(B, 1, X)
|
| 61 |
+
|
| 62 |
+
grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
|
| 63 |
+
grid_x = torch.reshape(grid_x, [1, 1, X])
|
| 64 |
+
grid_x = grid_x.repeat(B, Y, 1)
|
| 65 |
+
|
| 66 |
+
if norm:
|
| 67 |
+
grid_y, grid_x = normalize_grid2d(
|
| 68 |
+
grid_y, grid_x, Y, X)
|
| 69 |
+
|
| 70 |
+
if stack:
|
| 71 |
+
# note we stack in xy order
|
| 72 |
+
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
| 73 |
+
if on_chans:
|
| 74 |
+
grid = torch.stack([grid_x, grid_y], dim=1)
|
| 75 |
+
else:
|
| 76 |
+
grid = torch.stack([grid_x, grid_y], dim=-1)
|
| 77 |
+
return grid
|
| 78 |
+
else:
|
| 79 |
+
return grid_y, grid_x
|
| 80 |
+
|
| 81 |
+
def gridcloud2d(B, Y, X, norm=False, device='cuda'):
|
| 82 |
+
# we want to sample for each location in the grid
|
| 83 |
+
grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
|
| 84 |
+
x = torch.reshape(grid_x, [B, -1])
|
| 85 |
+
y = torch.reshape(grid_y, [B, -1])
|
| 86 |
+
# these are B x N
|
| 87 |
+
xy = torch.stack([x, y], dim=2)
|
| 88 |
+
# this is B x N x 2
|
| 89 |
+
return xy
|
| 90 |
+
|
| 91 |
+
def reduce_masked_mean(x, mask, dim=None, keepdim=False, broadcast=False):
|
| 92 |
+
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
| 93 |
+
# returns shape-1
|
| 94 |
+
# axis can be a list of axes
|
| 95 |
+
if not broadcast:
|
| 96 |
+
for (a,b) in zip(x.size(), mask.size()):
|
| 97 |
+
if not a==b:
|
| 98 |
+
print('some shape mismatch:', x.shape, mask.shape)
|
| 99 |
+
assert(a==b) # some shape mismatch!
|
| 100 |
+
# assert(x.size() == mask.size())
|
| 101 |
+
prod = x*mask
|
| 102 |
+
if dim is None:
|
| 103 |
+
numer = torch.sum(prod)
|
| 104 |
+
denom = EPS+torch.sum(mask)
|
| 105 |
+
else:
|
| 106 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
| 107 |
+
denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
|
| 108 |
+
mean = numer/denom
|
| 109 |
+
return mean
|
| 110 |
+
|
| 111 |
+
def reduce_masked_median(x, mask, keep_batch=False):
|
| 112 |
+
# x and mask are the same shape
|
| 113 |
+
assert(x.size() == mask.size())
|
| 114 |
+
device = x.device
|
| 115 |
+
|
| 116 |
+
B = list(x.shape)[0]
|
| 117 |
+
x = x.detach().cpu().numpy()
|
| 118 |
+
mask = mask.detach().cpu().numpy()
|
| 119 |
+
|
| 120 |
+
if keep_batch:
|
| 121 |
+
x = np.reshape(x, [B, -1])
|
| 122 |
+
mask = np.reshape(mask, [B, -1])
|
| 123 |
+
meds = np.zeros([B], np.float32)
|
| 124 |
+
for b in list(range(B)):
|
| 125 |
+
xb = x[b]
|
| 126 |
+
mb = mask[b]
|
| 127 |
+
if np.sum(mb) > 0:
|
| 128 |
+
xb = xb[mb > 0]
|
| 129 |
+
meds[b] = np.median(xb)
|
| 130 |
+
else:
|
| 131 |
+
meds[b] = np.nan
|
| 132 |
+
meds = torch.from_numpy(meds).to(device)
|
| 133 |
+
return meds.float()
|
| 134 |
+
else:
|
| 135 |
+
x = np.reshape(x, [-1])
|
| 136 |
+
mask = np.reshape(mask, [-1])
|
| 137 |
+
if np.sum(mask) > 0:
|
| 138 |
+
x = x[mask > 0]
|
| 139 |
+
med = np.median(x)
|
| 140 |
+
else:
|
| 141 |
+
med = np.nan
|
| 142 |
+
med = np.array([med], np.float32)
|
| 143 |
+
med = torch.from_numpy(med).to(device)
|
| 144 |
+
return med.float()
|
utils/data.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import dataclasses
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Optional, Dict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(eq=False)
|
| 9 |
+
class VideoData:
|
| 10 |
+
"""
|
| 11 |
+
Dataclass for storing video tracks data.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
video: torch.Tensor # B,S,C,H,W
|
| 15 |
+
trajs: torch.Tensor # B,S,N,2
|
| 16 |
+
visibs: torch.Tensor # B,S,N
|
| 17 |
+
valids: Optional[torch.Tensor] = None # B,S,N
|
| 18 |
+
seq_name: Optional[str] = None
|
| 19 |
+
dname: Optional[str] = None
|
| 20 |
+
aug_video: Optional[torch.Tensor] = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def collate_fn(batch):
|
| 24 |
+
"""
|
| 25 |
+
Collate function for video tracks data.
|
| 26 |
+
"""
|
| 27 |
+
video = torch.stack([b.video for b in batch], dim=0)
|
| 28 |
+
trajs = torch.stack([b.trajs for b in batch], dim=0)
|
| 29 |
+
visibs = torch.stack([b.visibs for b in batch], dim=0)
|
| 30 |
+
seq_name = [b.seq_name for b in batch]
|
| 31 |
+
dname = [b.dname for b in batch]
|
| 32 |
+
|
| 33 |
+
return VideoData(
|
| 34 |
+
video=video,
|
| 35 |
+
trajs=trajs,
|
| 36 |
+
visibs=visibs,
|
| 37 |
+
seq_name=seq_name,
|
| 38 |
+
dname=dname,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def collate_fn_train(batch):
|
| 43 |
+
"""
|
| 44 |
+
Collate function for video tracks data during training.
|
| 45 |
+
"""
|
| 46 |
+
gotit = [gotit for _, gotit in batch]
|
| 47 |
+
video = torch.stack([b.video for b, _ in batch], dim=0)
|
| 48 |
+
trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
|
| 49 |
+
visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
|
| 50 |
+
valids = torch.stack([b.valids for b, _ in batch], dim=0)
|
| 51 |
+
seq_name = [b.seq_name for b, _ in batch]
|
| 52 |
+
dname = [b.dname for b, _ in batch]
|
| 53 |
+
|
| 54 |
+
return (
|
| 55 |
+
VideoData(
|
| 56 |
+
video=video,
|
| 57 |
+
trajs=trajs,
|
| 58 |
+
visibs=visibs,
|
| 59 |
+
valids=valids,
|
| 60 |
+
seq_name=seq_name,
|
| 61 |
+
dname=dname,
|
| 62 |
+
),
|
| 63 |
+
gotit,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def try_to_cuda(t: Any) -> Any:
|
| 68 |
+
"""
|
| 69 |
+
Try to move the input variable `t` to a cuda device.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
t: Input.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
t_cuda: `t` moved to a cuda device, if supported.
|
| 76 |
+
"""
|
| 77 |
+
try:
|
| 78 |
+
t = t.float().cuda()
|
| 79 |
+
except AttributeError:
|
| 80 |
+
pass
|
| 81 |
+
return t
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def dataclass_to_cuda_(obj):
|
| 85 |
+
"""
|
| 86 |
+
Move all contents of a dataclass to cuda inplace if supported.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
batch: Input dataclass.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
batch_cuda: `batch` moved to a cuda device, if supported.
|
| 93 |
+
"""
|
| 94 |
+
for f in dataclasses.fields(obj):
|
| 95 |
+
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
| 96 |
+
return obj
|
utils/improc.py
ADDED
|
@@ -0,0 +1,1103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import utils.basic
|
| 4 |
+
import utils.py
|
| 5 |
+
from sklearn.decomposition import PCA
|
| 6 |
+
from matplotlib import cm
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import cv2
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
EPS = 1e-6
|
| 11 |
+
|
| 12 |
+
from skimage.color import (
|
| 13 |
+
rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
|
| 14 |
+
rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
|
| 15 |
+
|
| 16 |
+
def _convert(input_, type_):
|
| 17 |
+
return {
|
| 18 |
+
'float': input_.float(),
|
| 19 |
+
'double': input_.double(),
|
| 20 |
+
}.get(type_, input_)
|
| 21 |
+
|
| 22 |
+
def _generic_transform_sk_3d(transform, in_type='', out_type=''):
|
| 23 |
+
def apply_transform_individual(input_):
|
| 24 |
+
device = input_.device
|
| 25 |
+
input_ = input_.cpu()
|
| 26 |
+
input_ = _convert(input_, in_type)
|
| 27 |
+
|
| 28 |
+
input_ = input_.permute(1, 2, 0).detach().numpy()
|
| 29 |
+
transformed = transform(input_)
|
| 30 |
+
output = torch.from_numpy(transformed).float().permute(2, 0, 1)
|
| 31 |
+
output = _convert(output, out_type)
|
| 32 |
+
return output.to(device)
|
| 33 |
+
|
| 34 |
+
def apply_transform(input_):
|
| 35 |
+
to_stack = []
|
| 36 |
+
for image in input_:
|
| 37 |
+
to_stack.append(apply_transform_individual(image))
|
| 38 |
+
return torch.stack(to_stack)
|
| 39 |
+
return apply_transform
|
| 40 |
+
|
| 41 |
+
hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
|
| 42 |
+
|
| 43 |
+
def flow2color(flow, clip=0.0):
|
| 44 |
+
B, C, H, W = list(flow.size())
|
| 45 |
+
assert(C==2)
|
| 46 |
+
flow = flow[0:1].detach()
|
| 47 |
+
if clip==0:
|
| 48 |
+
clip = torch.max(torch.abs(flow)).item()
|
| 49 |
+
flow = torch.clamp(flow, -clip, clip)/clip
|
| 50 |
+
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) # B,1,H,W
|
| 51 |
+
radius_clipped = torch.clamp(radius, 0.0, 1.0)
|
| 52 |
+
angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B,1,H,W
|
| 53 |
+
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
|
| 54 |
+
saturation = torch.ones_like(hue) * 0.75
|
| 55 |
+
value = radius_clipped
|
| 56 |
+
hsv = torch.cat([hue, saturation, value], dim=1) # B,3,H,W
|
| 57 |
+
flow = hsv_to_rgb(hsv)
|
| 58 |
+
flow = (flow*255.0).type(torch.ByteTensor)
|
| 59 |
+
return flow
|
| 60 |
+
|
| 61 |
+
COLORMAP_FILE = "./utils/bremm.png"
|
| 62 |
+
class ColorMap2d:
|
| 63 |
+
def __init__(self, filename=None):
|
| 64 |
+
self._colormap_file = filename or COLORMAP_FILE
|
| 65 |
+
self._img = (plt.imread(self._colormap_file)*255).astype(np.uint8)
|
| 66 |
+
|
| 67 |
+
self._height = self._img.shape[0]
|
| 68 |
+
self._width = self._img.shape[1]
|
| 69 |
+
|
| 70 |
+
def __call__(self, X):
|
| 71 |
+
assert len(X.shape) == 2
|
| 72 |
+
output = np.zeros((X.shape[0], 3), dtype=np.uint8)
|
| 73 |
+
for i in range(X.shape[0]):
|
| 74 |
+
x, y = X[i, :]
|
| 75 |
+
xp = int((self._width-1) * x)
|
| 76 |
+
yp = int((self._height-1) * y)
|
| 77 |
+
xp = np.clip(xp, 0, self._width-1)
|
| 78 |
+
yp = np.clip(yp, 0, self._height-1)
|
| 79 |
+
output[i, :] = self._img[yp, xp]
|
| 80 |
+
return output
|
| 81 |
+
|
| 82 |
+
def get_2d_colors(xys, H, W):
|
| 83 |
+
N,D = xys.shape
|
| 84 |
+
assert(D==2)
|
| 85 |
+
bremm = ColorMap2d()
|
| 86 |
+
xys[:,0] /= float(W-1)
|
| 87 |
+
xys[:,1] /= float(H-1)
|
| 88 |
+
colors = bremm(xys)
|
| 89 |
+
# print('colors', colors)
|
| 90 |
+
# colors = (colors[0]*255).astype(np.uint8)
|
| 91 |
+
# colors = (int(colors[0]),int(colors[1]),int(colors[2]))
|
| 92 |
+
return colors
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_n_colors(N, sequential=False):
|
| 96 |
+
label_colors = []
|
| 97 |
+
for ii in range(N):
|
| 98 |
+
if sequential:
|
| 99 |
+
rgb = cm.winter(ii/(N-1))
|
| 100 |
+
rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
|
| 101 |
+
else:
|
| 102 |
+
rgb = np.zeros(3)
|
| 103 |
+
while np.sum(rgb) < 128: # ensure min brightness
|
| 104 |
+
rgb = np.random.randint(0,256,3)
|
| 105 |
+
label_colors.append(rgb)
|
| 106 |
+
return label_colors
|
| 107 |
+
|
| 108 |
+
def pca_embed(emb, keep, valid=None):
|
| 109 |
+
# helper function for reduce_emb
|
| 110 |
+
# emb is B,C,H,W
|
| 111 |
+
# keep is the number of principal components to keep
|
| 112 |
+
emb = emb + EPS
|
| 113 |
+
emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C
|
| 114 |
+
|
| 115 |
+
if valid:
|
| 116 |
+
valid = valid.cpu().detach().numpy().reshape((H*W))
|
| 117 |
+
|
| 118 |
+
emb_reduced = list()
|
| 119 |
+
|
| 120 |
+
B, H, W, C = np.shape(emb)
|
| 121 |
+
for img in emb:
|
| 122 |
+
if np.isnan(img).any():
|
| 123 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
| 124 |
+
continue
|
| 125 |
+
|
| 126 |
+
pixels_kd = np.reshape(img, (H*W, C))
|
| 127 |
+
|
| 128 |
+
if valid:
|
| 129 |
+
pixels_kd_pca = pixels_kd[valid]
|
| 130 |
+
else:
|
| 131 |
+
pixels_kd_pca = pixels_kd
|
| 132 |
+
|
| 133 |
+
P = PCA(keep)
|
| 134 |
+
P.fit(pixels_kd_pca)
|
| 135 |
+
|
| 136 |
+
if valid:
|
| 137 |
+
pixels3d = P.transform(pixels_kd)*valid
|
| 138 |
+
else:
|
| 139 |
+
pixels3d = P.transform(pixels_kd)
|
| 140 |
+
|
| 141 |
+
out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32)
|
| 142 |
+
if np.isnan(out_img).any():
|
| 143 |
+
emb_reduced.append(np.zeros([H, W, keep]))
|
| 144 |
+
continue
|
| 145 |
+
|
| 146 |
+
emb_reduced.append(out_img)
|
| 147 |
+
|
| 148 |
+
emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
|
| 149 |
+
|
| 150 |
+
return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
|
| 151 |
+
|
| 152 |
+
def pca_embed_together(emb, keep):
|
| 153 |
+
# emb is B,C,H,W
|
| 154 |
+
# keep is the number of principal components to keep
|
| 155 |
+
emb = emb + EPS
|
| 156 |
+
emb = emb.permute(0, 2, 3, 1).cpu().detach().float().numpy() #this is B x H x W x C
|
| 157 |
+
|
| 158 |
+
B, H, W, C = np.shape(emb)
|
| 159 |
+
if np.isnan(emb).any():
|
| 160 |
+
return torch.zeros(B, keep, H, W)
|
| 161 |
+
|
| 162 |
+
pixelskd = np.reshape(emb, (B*H*W, C))
|
| 163 |
+
P = PCA(keep)
|
| 164 |
+
P.fit(pixelskd)
|
| 165 |
+
pixels3d = P.transform(pixelskd)
|
| 166 |
+
out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32)
|
| 167 |
+
|
| 168 |
+
if np.isnan(out_img).any():
|
| 169 |
+
return torch.zeros(B, keep, H, W)
|
| 170 |
+
|
| 171 |
+
return torch.from_numpy(out_img).permute(0, 3, 1, 2)
|
| 172 |
+
|
| 173 |
+
def reduce_emb(emb, valid=None, inbound=None, together=False):
|
| 174 |
+
S, C, H, W = list(emb.size())
|
| 175 |
+
keep = 4
|
| 176 |
+
|
| 177 |
+
if together:
|
| 178 |
+
reduced_emb = pca_embed_together(emb, keep)
|
| 179 |
+
else:
|
| 180 |
+
reduced_emb = pca_embed(emb, keep, valid) #not im
|
| 181 |
+
|
| 182 |
+
reduced_emb = reduced_emb[:,1:]
|
| 183 |
+
reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
|
| 184 |
+
if inbound is not None:
|
| 185 |
+
emb_inbound = emb*inbound
|
| 186 |
+
else:
|
| 187 |
+
emb_inbound = None
|
| 188 |
+
|
| 189 |
+
return reduced_emb, emb_inbound
|
| 190 |
+
|
| 191 |
+
def get_feat_pca(feat, valid=None):
|
| 192 |
+
B, C, D, W = list(feat.size())
|
| 193 |
+
pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True)
|
| 194 |
+
return pca
|
| 195 |
+
|
| 196 |
+
def gif_and_tile(ims, just_gif=False):
|
| 197 |
+
S = len(ims)
|
| 198 |
+
# each im is B x H x W x C
|
| 199 |
+
# i want a gif in the left, and the tiled frames on the right
|
| 200 |
+
# for the gif tool, this means making a B x S x H x W tensor
|
| 201 |
+
# where the leftmost part is sequential and the rest is tiled
|
| 202 |
+
gif = torch.stack(ims, dim=1)
|
| 203 |
+
if just_gif:
|
| 204 |
+
return gif
|
| 205 |
+
til = torch.cat(ims, dim=2)
|
| 206 |
+
til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)
|
| 207 |
+
im = torch.cat([gif, til], dim=3)
|
| 208 |
+
return im
|
| 209 |
+
|
| 210 |
+
def preprocess_color(x):
|
| 211 |
+
if isinstance(x, np.ndarray):
|
| 212 |
+
return x.astype(np.float32) * 1./255 - 0.5
|
| 213 |
+
else:
|
| 214 |
+
return x.float() * 1./255 - 0.5
|
| 215 |
+
|
| 216 |
+
def back2color(i, blacken_zeros=False):
|
| 217 |
+
if blacken_zeros:
|
| 218 |
+
const = torch.tensor([-0.5])
|
| 219 |
+
i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i)
|
| 220 |
+
return back2color(i)
|
| 221 |
+
else:
|
| 222 |
+
return ((i+0.5)*255).type(torch.ByteTensor)
|
| 223 |
+
|
| 224 |
+
def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True):
|
| 225 |
+
|
| 226 |
+
rgb = vis.detach().cpu().numpy()[0]
|
| 227 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
| 228 |
+
rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
| 229 |
+
color = (255, 255, 255)
|
| 230 |
+
# print('putting frame id', frame_id)
|
| 231 |
+
|
| 232 |
+
frame_str = utils.basic.strnum(frame_id)
|
| 233 |
+
|
| 234 |
+
text_color_bg = (0,0,0)
|
| 235 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 236 |
+
text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
|
| 237 |
+
text_w, text_h = text_size
|
| 238 |
+
if shadow:
|
| 239 |
+
cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
|
| 240 |
+
|
| 241 |
+
cv2.putText(
|
| 242 |
+
rgb,
|
| 243 |
+
frame_str,
|
| 244 |
+
(left, top), # from left, from top
|
| 245 |
+
font,
|
| 246 |
+
scale, # font scale (float)
|
| 247 |
+
color,
|
| 248 |
+
1) # font thickness (int)
|
| 249 |
+
rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
| 250 |
+
vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
| 251 |
+
return vis
|
| 252 |
+
|
| 253 |
+
def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True):
|
| 254 |
+
|
| 255 |
+
rgb = vis.detach().cpu().numpy()[0]
|
| 256 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
| 257 |
+
rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
|
| 258 |
+
color = (255, 255, 255)
|
| 259 |
+
|
| 260 |
+
text_color_bg = (0,0,0)
|
| 261 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 262 |
+
text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
|
| 263 |
+
text_w, text_h = text_size
|
| 264 |
+
if shadow:
|
| 265 |
+
cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
|
| 266 |
+
|
| 267 |
+
cv2.putText(
|
| 268 |
+
rgb,
|
| 269 |
+
frame_str,
|
| 270 |
+
(left, top), # from left, from top
|
| 271 |
+
font,
|
| 272 |
+
scale, # font scale (float)
|
| 273 |
+
color,
|
| 274 |
+
1) # font thickness (int)
|
| 275 |
+
rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
| 276 |
+
vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
| 277 |
+
return vis
|
| 278 |
+
|
| 279 |
+
class Summ_writer(object):
|
| 280 |
+
def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
|
| 281 |
+
self.writer = writer
|
| 282 |
+
self.global_step = global_step
|
| 283 |
+
self.log_freq = log_freq
|
| 284 |
+
self.scalar_freq = scalar_freq
|
| 285 |
+
self.fps = fps
|
| 286 |
+
self.just_gif = just_gif
|
| 287 |
+
self.maxwidth = 10000
|
| 288 |
+
self.save_this = (self.global_step % self.log_freq == 0)
|
| 289 |
+
self.scalar_freq = max(scalar_freq,1)
|
| 290 |
+
self.save_scalar = (self.global_step % self.scalar_freq == 0)
|
| 291 |
+
if self.save_this:
|
| 292 |
+
self.save_scalar = True
|
| 293 |
+
|
| 294 |
+
def summ_gif(self, name, tensor, blacken_zeros=False):
|
| 295 |
+
# tensor should be in B x S x C x H x W
|
| 296 |
+
|
| 297 |
+
assert tensor.dtype in {torch.uint8,torch.float32}
|
| 298 |
+
shape = list(tensor.shape)
|
| 299 |
+
|
| 300 |
+
if tensor.dtype == torch.float32:
|
| 301 |
+
tensor = back2color(tensor, blacken_zeros=blacken_zeros)
|
| 302 |
+
|
| 303 |
+
video_to_write = tensor[0:1]
|
| 304 |
+
|
| 305 |
+
S = video_to_write.shape[1]
|
| 306 |
+
if S==1:
|
| 307 |
+
# video_to_write is 1 x 1 x C x H x W
|
| 308 |
+
self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
|
| 309 |
+
else:
|
| 310 |
+
self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
|
| 311 |
+
|
| 312 |
+
return video_to_write
|
| 313 |
+
|
| 314 |
+
def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False):
|
| 315 |
+
if self.save_this:
|
| 316 |
+
|
| 317 |
+
ims = gif_and_tile(ims, just_gif=self.just_gif)
|
| 318 |
+
vis = ims
|
| 319 |
+
|
| 320 |
+
assert vis.dtype in {torch.uint8,torch.float32}
|
| 321 |
+
|
| 322 |
+
if vis.dtype == torch.float32:
|
| 323 |
+
vis = back2color(vis, blacken_zeros)
|
| 324 |
+
|
| 325 |
+
B, S, C, H, W = list(vis.shape)
|
| 326 |
+
|
| 327 |
+
if frame_ids is not None:
|
| 328 |
+
assert(len(frame_ids)==S)
|
| 329 |
+
for s in range(S):
|
| 330 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
| 331 |
+
|
| 332 |
+
if frame_strs is not None:
|
| 333 |
+
assert(len(frame_strs)==S)
|
| 334 |
+
for s in range(S):
|
| 335 |
+
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
|
| 336 |
+
|
| 337 |
+
if int(W) > self.maxwidth:
|
| 338 |
+
vis = vis[:,:,:,:self.maxwidth]
|
| 339 |
+
|
| 340 |
+
if only_return:
|
| 341 |
+
return vis
|
| 342 |
+
else:
|
| 343 |
+
return self.summ_gif(name, vis, blacken_zeros)
|
| 344 |
+
|
| 345 |
+
def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True):
|
| 346 |
+
if self.save_this:
|
| 347 |
+
assert ims.dtype in {torch.uint8,torch.float32}
|
| 348 |
+
|
| 349 |
+
if ims.dtype == torch.float32:
|
| 350 |
+
ims = back2color(ims, blacken_zeros)
|
| 351 |
+
|
| 352 |
+
#ims is B x C x H x W
|
| 353 |
+
vis = ims[0:1] # just the first one
|
| 354 |
+
B, C, H, W = list(vis.shape)
|
| 355 |
+
|
| 356 |
+
if halfres:
|
| 357 |
+
vis = F.interpolate(vis, scale_factor=0.5)
|
| 358 |
+
|
| 359 |
+
if frame_id is not None:
|
| 360 |
+
vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow)
|
| 361 |
+
|
| 362 |
+
if frame_str is not None:
|
| 363 |
+
vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow)
|
| 364 |
+
|
| 365 |
+
if int(W) > self.maxwidth:
|
| 366 |
+
vis = vis[:,:,:,:self.maxwidth]
|
| 367 |
+
|
| 368 |
+
if only_return:
|
| 369 |
+
return vis
|
| 370 |
+
else:
|
| 371 |
+
return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
|
| 372 |
+
|
| 373 |
+
def flow2color(self, flow, clip=0.0):
|
| 374 |
+
B, C, H, W = list(flow.size())
|
| 375 |
+
assert(C==2)
|
| 376 |
+
flow = flow[0:1].detach()
|
| 377 |
+
|
| 378 |
+
if False:
|
| 379 |
+
flow = flow[0].detach().cpu().permute(1,2,0).numpy() # H,W,2
|
| 380 |
+
if clip > 0:
|
| 381 |
+
clip_flow = clip
|
| 382 |
+
else:
|
| 383 |
+
clip_flow = None
|
| 384 |
+
im = utils.py.flow_to_image(flow, clip_flow=clip_flow, convert_to_bgr=True)
|
| 385 |
+
# im = utils.py.flow_to_image(flow, convert_to_bgr=True)
|
| 386 |
+
im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).byte() # 1,3,H,W
|
| 387 |
+
im = torch.flip(im, dims=[1]).clone() # BGR
|
| 388 |
+
|
| 389 |
+
# # i prefer black bkg
|
| 390 |
+
# white_pixels = (im == 255).all(dim=1, keepdim=True)
|
| 391 |
+
# im[white_pixels.expand(-1, 3, -1, -1)] = 0
|
| 392 |
+
|
| 393 |
+
return im
|
| 394 |
+
|
| 395 |
+
# flow_abs = torch.abs(flow)
|
| 396 |
+
# flow_mean = flow_abs.mean(dim=[1,2,3])
|
| 397 |
+
# flow_std = flow_abs.std(dim=[1,2,3])
|
| 398 |
+
if clip==0:
|
| 399 |
+
clip = torch.max(torch.abs(flow)).item()
|
| 400 |
+
|
| 401 |
+
# if clip:
|
| 402 |
+
flow = torch.clamp(flow, -clip, clip)/clip
|
| 403 |
+
# else:
|
| 404 |
+
# # # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
|
| 405 |
+
# # flow_max = flow_mean + flow_std*2 + 1e-10
|
| 406 |
+
# # for b in range(B):
|
| 407 |
+
# # flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
|
| 408 |
+
|
| 409 |
+
# flow_max = torch.max(flow_abs[b])
|
| 410 |
+
# for b in range(B):
|
| 411 |
+
# flow[b] = flow[b].clamp(-flow_max.item(), flow_max.item()) / flow_max[b].clamp(min=1)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
|
| 415 |
+
radius_clipped = torch.clamp(radius, 0.0, 1.0)
|
| 416 |
+
|
| 417 |
+
angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W
|
| 418 |
+
|
| 419 |
+
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
|
| 420 |
+
# hue = torch.mod(angle / (2 * np.pi) + 1.0, 1.0)
|
| 421 |
+
|
| 422 |
+
saturation = torch.ones_like(hue) * 0.75
|
| 423 |
+
value = radius_clipped
|
| 424 |
+
hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
|
| 425 |
+
|
| 426 |
+
#flow = tf.image.hsv_to_rgb(hsv)
|
| 427 |
+
flow = hsv_to_rgb(hsv)
|
| 428 |
+
flow = (flow*255.0).type(torch.ByteTensor)
|
| 429 |
+
# flow = torch.flip(flow, dims=[1]).clone() # BGR
|
| 430 |
+
return flow
|
| 431 |
+
|
| 432 |
+
def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True):
|
| 433 |
+
# flow is B x C x D x W
|
| 434 |
+
if self.save_this:
|
| 435 |
+
return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow)
|
| 436 |
+
else:
|
| 437 |
+
return None
|
| 438 |
+
|
| 439 |
+
def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
|
| 440 |
+
if self.save_this:
|
| 441 |
+
if bev:
|
| 442 |
+
B, C, H, _, W = list(ims[0].shape)
|
| 443 |
+
if reduce_max:
|
| 444 |
+
ims = [torch.max(im, dim=3)[0] for im in ims]
|
| 445 |
+
else:
|
| 446 |
+
ims = [torch.mean(im, dim=3) for im in ims]
|
| 447 |
+
elif fro:
|
| 448 |
+
B, C, _, H, W = list(ims[0].shape)
|
| 449 |
+
if reduce_max:
|
| 450 |
+
ims = [torch.max(im, dim=2)[0] for im in ims]
|
| 451 |
+
else:
|
| 452 |
+
ims = [torch.mean(im, dim=2) for im in ims]
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
if len(ims) != 1: # sequence
|
| 456 |
+
im = gif_and_tile(ims, just_gif=self.just_gif)
|
| 457 |
+
else:
|
| 458 |
+
im = torch.stack(ims, dim=1) # single frame
|
| 459 |
+
|
| 460 |
+
B, S, C, H, W = list(im.shape)
|
| 461 |
+
|
| 462 |
+
if logvis and max_val:
|
| 463 |
+
max_val = np.log(max_val)
|
| 464 |
+
im = torch.log(torch.clamp(im, 0)+1.0)
|
| 465 |
+
im = torch.clamp(im, 0, max_val)
|
| 466 |
+
im = im/max_val
|
| 467 |
+
norm = False
|
| 468 |
+
elif max_val:
|
| 469 |
+
im = torch.clamp(im, 0, max_val)
|
| 470 |
+
im = im/max_val
|
| 471 |
+
norm = False
|
| 472 |
+
|
| 473 |
+
if norm:
|
| 474 |
+
# normalize before oned2inferno,
|
| 475 |
+
# so that the ranges are similar within B across S
|
| 476 |
+
im = utils.basic.normalize(im)
|
| 477 |
+
|
| 478 |
+
im = im.view(B*S, C, H, W)
|
| 479 |
+
vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
|
| 480 |
+
vis = vis.view(B, S, 3, H, W)
|
| 481 |
+
|
| 482 |
+
if frame_ids is not None:
|
| 483 |
+
assert(len(frame_ids)==S)
|
| 484 |
+
for s in range(S):
|
| 485 |
+
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
|
| 486 |
+
|
| 487 |
+
if frame_strs is not None:
|
| 488 |
+
assert(len(frame_strs)==S)
|
| 489 |
+
for s in range(S):
|
| 490 |
+
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
|
| 491 |
+
|
| 492 |
+
if W > self.maxwidth:
|
| 493 |
+
vis = vis[...,:self.maxwidth]
|
| 494 |
+
|
| 495 |
+
if only_return:
|
| 496 |
+
return vis
|
| 497 |
+
else:
|
| 498 |
+
self.summ_gif(name, vis)
|
| 499 |
+
|
| 500 |
+
def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True):
|
| 501 |
+
if self.save_this:
|
| 502 |
+
|
| 503 |
+
if bev:
|
| 504 |
+
B, C, H, _, W = list(im.shape)
|
| 505 |
+
if max_along_y:
|
| 506 |
+
im = torch.max(im, dim=3)[0]
|
| 507 |
+
else:
|
| 508 |
+
im = torch.mean(im, dim=3)
|
| 509 |
+
elif fro:
|
| 510 |
+
B, C, _, H, W = list(im.shape)
|
| 511 |
+
if max_along_y:
|
| 512 |
+
im = torch.max(im, dim=2)[0]
|
| 513 |
+
else:
|
| 514 |
+
im = torch.mean(im, dim=2)
|
| 515 |
+
else:
|
| 516 |
+
B, C, H, W = list(im.shape)
|
| 517 |
+
|
| 518 |
+
im = im[0:1] # just the first one
|
| 519 |
+
assert(C==1)
|
| 520 |
+
|
| 521 |
+
if logvis and max_val:
|
| 522 |
+
max_val = np.log(max_val)
|
| 523 |
+
im = torch.log(im)
|
| 524 |
+
im = torch.clamp(im, 0, max_val)
|
| 525 |
+
im = im/max_val
|
| 526 |
+
norm = False
|
| 527 |
+
elif max_val:
|
| 528 |
+
im = torch.clamp(im, 0, max_val)/max_val
|
| 529 |
+
norm = False
|
| 530 |
+
|
| 531 |
+
vis = oned2inferno(im, norm=norm)
|
| 532 |
+
if W > self.maxwidth:
|
| 533 |
+
vis = vis[...,:self.maxwidth]
|
| 534 |
+
return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None):
|
| 538 |
+
if self.save_this:
|
| 539 |
+
if valids is not None:
|
| 540 |
+
valids = torch.stack(valids, dim=1)
|
| 541 |
+
|
| 542 |
+
feats = torch.stack(feats, dim=1)
|
| 543 |
+
# feats leads with B x S x C
|
| 544 |
+
|
| 545 |
+
if feats.ndim==6:
|
| 546 |
+
|
| 547 |
+
# feats is B x S x C x D x H x W
|
| 548 |
+
if fro:
|
| 549 |
+
reduce_dim = 3
|
| 550 |
+
else:
|
| 551 |
+
reduce_dim = 4
|
| 552 |
+
|
| 553 |
+
if valids is None:
|
| 554 |
+
feats = torch.mean(feats, dim=reduce_dim)
|
| 555 |
+
else:
|
| 556 |
+
valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
|
| 557 |
+
feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
|
| 558 |
+
|
| 559 |
+
B, S, C, D, W = list(feats.size())
|
| 560 |
+
|
| 561 |
+
if not pca:
|
| 562 |
+
# feats leads with B x S x C
|
| 563 |
+
feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
|
| 564 |
+
# feats leads with B x S x 1
|
| 565 |
+
feats = torch.unbind(feats, dim=1)
|
| 566 |
+
return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
| 567 |
+
|
| 568 |
+
else:
|
| 569 |
+
__p = lambda x: utils.basic.pack_seqdim(x, B)
|
| 570 |
+
__u = lambda x: utils.basic.unpack_seqdim(x, B)
|
| 571 |
+
|
| 572 |
+
feats_ = __p(feats)
|
| 573 |
+
|
| 574 |
+
if valids is None:
|
| 575 |
+
feats_pca_ = get_feat_pca(feats_)
|
| 576 |
+
else:
|
| 577 |
+
valids_ = __p(valids)
|
| 578 |
+
feats_pca_ = get_feat_pca(feats_, valids)
|
| 579 |
+
|
| 580 |
+
feats_pca = __u(feats_pca_)
|
| 581 |
+
|
| 582 |
+
return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
| 583 |
+
|
| 584 |
+
def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None):
|
| 585 |
+
if self.save_this:
|
| 586 |
+
if feat.ndim==5: # B x C x D x H x W
|
| 587 |
+
|
| 588 |
+
if bev:
|
| 589 |
+
reduce_axis = 3
|
| 590 |
+
elif fro:
|
| 591 |
+
reduce_axis = 2
|
| 592 |
+
else:
|
| 593 |
+
# default to bev
|
| 594 |
+
reduce_axis = 3
|
| 595 |
+
|
| 596 |
+
if valid is None:
|
| 597 |
+
feat = torch.mean(feat, dim=reduce_axis)
|
| 598 |
+
else:
|
| 599 |
+
valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
|
| 600 |
+
feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
|
| 601 |
+
|
| 602 |
+
B, C, D, W = list(feat.shape)
|
| 603 |
+
|
| 604 |
+
if not pca:
|
| 605 |
+
feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
|
| 606 |
+
# feat is B x 1 x D x W
|
| 607 |
+
return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
| 608 |
+
else:
|
| 609 |
+
feat_pca = get_feat_pca(feat, valid)
|
| 610 |
+
return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
| 611 |
+
|
| 612 |
+
def summ_scalar(self, name, value):
|
| 613 |
+
if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
|
| 614 |
+
value = value.detach().cpu().numpy()
|
| 615 |
+
if not np.isnan(value):
|
| 616 |
+
if (self.log_freq == 1):
|
| 617 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
| 618 |
+
elif self.save_this or self.save_scalar:
|
| 619 |
+
self.writer.add_scalar(name, value, global_step=self.global_step)
|
| 620 |
+
|
| 621 |
+
def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024):
|
| 622 |
+
# trajs is B, S, N, 2
|
| 623 |
+
# rgbs is B, S, C, H, W
|
| 624 |
+
B, S, C, H, W = rgbs.shape
|
| 625 |
+
B, S2, N, D = trajs.shape
|
| 626 |
+
assert(S==S2)
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
rgbs = rgbs[0] # S, C, H, W
|
| 630 |
+
trajs = trajs[0] # S, N, 2
|
| 631 |
+
if valids is None:
|
| 632 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
| 633 |
+
else:
|
| 634 |
+
valids = valids[0]
|
| 635 |
+
|
| 636 |
+
if visibs is None:
|
| 637 |
+
visibs = torch.ones_like(trajs[:,:,0]) # S, N
|
| 638 |
+
else:
|
| 639 |
+
visibs = visibs[0]
|
| 640 |
+
|
| 641 |
+
if vals is not None:
|
| 642 |
+
vals = vals[0] # N
|
| 643 |
+
# print('vals', vals.shape)
|
| 644 |
+
|
| 645 |
+
if N > max_show:
|
| 646 |
+
inds = np.random.choice(N, max_show)
|
| 647 |
+
trajs = trajs[:,inds]
|
| 648 |
+
valids = valids[:,inds]
|
| 649 |
+
visibs = visibs[:,inds]
|
| 650 |
+
if vals is not None:
|
| 651 |
+
vals = vals[inds]
|
| 652 |
+
N = trajs.shape[1]
|
| 653 |
+
|
| 654 |
+
trajs = trajs.clamp(-16, W+16)
|
| 655 |
+
|
| 656 |
+
rgbs_color = []
|
| 657 |
+
for rgb in rgbs:
|
| 658 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
| 659 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
| 660 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
| 661 |
+
|
| 662 |
+
for i in range(min(N, max_show)):
|
| 663 |
+
if cmap=='onediff' and i==0:
|
| 664 |
+
cmap_ = 'spring'
|
| 665 |
+
elif cmap=='onediff':
|
| 666 |
+
cmap_ = 'winter'
|
| 667 |
+
else:
|
| 668 |
+
cmap_ = cmap
|
| 669 |
+
traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
|
| 670 |
+
valid = valids[:,i].long().detach().cpu().numpy() # S
|
| 671 |
+
|
| 672 |
+
# print('traj', traj.shape)
|
| 673 |
+
# print('valid', valid.shape)
|
| 674 |
+
|
| 675 |
+
if vals is not None:
|
| 676 |
+
# val = vals[:,i].float().detach().cpu().numpy() # []
|
| 677 |
+
val = vals[i].float().detach().cpu().numpy() # []
|
| 678 |
+
# print('val', val.shape)
|
| 679 |
+
else:
|
| 680 |
+
val = None
|
| 681 |
+
|
| 682 |
+
for t in range(S):
|
| 683 |
+
if valid[t]:
|
| 684 |
+
rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth)
|
| 685 |
+
|
| 686 |
+
for i in range(min(N, max_show)):
|
| 687 |
+
if cmap=='onediff' and i==0:
|
| 688 |
+
cmap_ = 'spring'
|
| 689 |
+
elif cmap=='onediff':
|
| 690 |
+
cmap_ = 'winter'
|
| 691 |
+
else:
|
| 692 |
+
cmap_ = cmap
|
| 693 |
+
traj = trajs[:,i] # S,2
|
| 694 |
+
vis = visibs[:,i].round() # S
|
| 695 |
+
valid = valids[:,i] # S
|
| 696 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
| 697 |
+
|
| 698 |
+
rgbs = []
|
| 699 |
+
for rgb in rgbs_color:
|
| 700 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
| 701 |
+
rgbs.append(preprocess_color(rgb))
|
| 702 |
+
|
| 703 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
| 704 |
+
|
| 705 |
+
def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024):
|
| 706 |
+
# trajs is B, S, N, 2
|
| 707 |
+
# rgbs is B, S, C, H, W
|
| 708 |
+
B, S, C, H, W = rgbs.shape
|
| 709 |
+
B, S2, N, D = trajs.shape
|
| 710 |
+
assert(S==S2)
|
| 711 |
+
|
| 712 |
+
rgbs = rgbs[0] # S, C, H, W
|
| 713 |
+
trajs = trajs[0] # S, N, 2
|
| 714 |
+
visibles = visibles[0] # S, N
|
| 715 |
+
if valids is None:
|
| 716 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
| 717 |
+
else:
|
| 718 |
+
valids = valids[0]
|
| 719 |
+
|
| 720 |
+
rgbs_color = []
|
| 721 |
+
for rgb in rgbs:
|
| 722 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
| 723 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
| 724 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
| 725 |
+
|
| 726 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
| 727 |
+
visibles = visibles.float().detach().cpu().numpy() # S, N
|
| 728 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
| 729 |
+
|
| 730 |
+
for i in range(min(N, max_show)):
|
| 731 |
+
if cmap=='onediff' and i==0:
|
| 732 |
+
cmap_ = 'spring'
|
| 733 |
+
elif cmap=='onediff':
|
| 734 |
+
cmap_ = 'winter'
|
| 735 |
+
else:
|
| 736 |
+
cmap_ = cmap
|
| 737 |
+
traj = trajs[:,i] # S,2
|
| 738 |
+
vis = visibles[:,i] # S
|
| 739 |
+
valid = valids[:,i] # S
|
| 740 |
+
rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
|
| 741 |
+
|
| 742 |
+
for i in range(min(N, max_show)):
|
| 743 |
+
if cmap=='onediff' and i==0:
|
| 744 |
+
cmap_ = 'spring'
|
| 745 |
+
elif cmap=='onediff':
|
| 746 |
+
cmap_ = 'winter'
|
| 747 |
+
else:
|
| 748 |
+
cmap_ = cmap
|
| 749 |
+
traj = trajs[:,i] # S,2
|
| 750 |
+
vis = visibles[:,i] # S
|
| 751 |
+
valid = valids[:,i] # S
|
| 752 |
+
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
|
| 753 |
+
|
| 754 |
+
rgbs = []
|
| 755 |
+
for rgb in rgbs_color:
|
| 756 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
| 757 |
+
rgbs.append(preprocess_color(rgb))
|
| 758 |
+
|
| 759 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
| 760 |
+
|
| 761 |
+
def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024):
|
| 762 |
+
# trajs is B, S, N, 2
|
| 763 |
+
# rgb is B, C, H, W
|
| 764 |
+
B, C, H, W = rgb.shape
|
| 765 |
+
B, S, N, D = trajs.shape
|
| 766 |
+
|
| 767 |
+
rgb = rgb[0] # S, C, H, W
|
| 768 |
+
trajs = trajs[0] # S, N, 2
|
| 769 |
+
|
| 770 |
+
if valids is None:
|
| 771 |
+
valids = torch.ones_like(trajs[:,:,0])
|
| 772 |
+
else:
|
| 773 |
+
valids = valids[0]
|
| 774 |
+
|
| 775 |
+
rgb_color = back2color(rgb).detach().cpu().numpy()
|
| 776 |
+
rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
|
| 777 |
+
|
| 778 |
+
# using maxdist will dampen the colors for short motions
|
| 779 |
+
# norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N
|
| 780 |
+
# maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
|
| 781 |
+
maxdist = None
|
| 782 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
| 783 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
| 784 |
+
|
| 785 |
+
if N > max_show:
|
| 786 |
+
inds = np.random.choice(N, max_show)
|
| 787 |
+
trajs = trajs[:,inds]
|
| 788 |
+
valids = valids[:,inds]
|
| 789 |
+
N = trajs.shape[1]
|
| 790 |
+
|
| 791 |
+
for i in range(min(N, max_show)):
|
| 792 |
+
if cmap=='onediff' and i==0:
|
| 793 |
+
cmap_ = 'spring'
|
| 794 |
+
elif cmap=='onediff':
|
| 795 |
+
cmap_ = 'winter'
|
| 796 |
+
else:
|
| 797 |
+
cmap_ = cmap
|
| 798 |
+
traj = trajs[:,i] # S, 2
|
| 799 |
+
valid = valids[:,i] # S
|
| 800 |
+
if valid[0]==1:
|
| 801 |
+
traj = traj[valid>0]
|
| 802 |
+
rgb_color = self.draw_traj_on_image_py(
|
| 803 |
+
rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
|
| 804 |
+
|
| 805 |
+
rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
|
| 806 |
+
rgb = preprocess_color(rgb_color)
|
| 807 |
+
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
| 808 |
+
|
| 809 |
+
def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
|
| 810 |
+
# all inputs are numpy tensors
|
| 811 |
+
# rgb is 3 x H x W
|
| 812 |
+
# traj is S x 2
|
| 813 |
+
|
| 814 |
+
H, W, C = rgb.shape
|
| 815 |
+
assert(C==3)
|
| 816 |
+
|
| 817 |
+
rgb = rgb.astype(np.uint8).copy()
|
| 818 |
+
|
| 819 |
+
S1, D = traj.shape
|
| 820 |
+
assert(D==2)
|
| 821 |
+
|
| 822 |
+
color_map = cm.get_cmap(cmap)
|
| 823 |
+
S1, D = traj.shape
|
| 824 |
+
|
| 825 |
+
for s in range(S1):
|
| 826 |
+
if val is not None:
|
| 827 |
+
color = np.array(color_map(val)[:3]) * 255 # rgb
|
| 828 |
+
else:
|
| 829 |
+
if maxdist is not None:
|
| 830 |
+
val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
|
| 831 |
+
color = np.array(color_map(val)[:3]) * 255 # rgb
|
| 832 |
+
else:
|
| 833 |
+
color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
|
| 834 |
+
|
| 835 |
+
if show_lines and s<(S1-1):
|
| 836 |
+
cv2.line(rgb,
|
| 837 |
+
(int(traj[s,0]), int(traj[s,1])),
|
| 838 |
+
(int(traj[s+1,0]), int(traj[s+1,1])),
|
| 839 |
+
color,
|
| 840 |
+
linewidth,
|
| 841 |
+
cv2.LINE_AA)
|
| 842 |
+
if show_dots:
|
| 843 |
+
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
|
| 844 |
+
|
| 845 |
+
# if maxdist is not None:
|
| 846 |
+
# val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
|
| 847 |
+
# color = np.array(color_map(val)[:3]) * 255 # rgb
|
| 848 |
+
# else:
|
| 849 |
+
# # draw the endpoint of traj, using the next color (which may be the last color)
|
| 850 |
+
# color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
|
| 851 |
+
|
| 852 |
+
# # emphasize endpoint
|
| 853 |
+
# cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
|
| 854 |
+
|
| 855 |
+
return rgb
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
|
| 859 |
+
# all inputs are numpy tensors
|
| 860 |
+
# rgbs is a list of H,W,3
|
| 861 |
+
# traj is S,2
|
| 862 |
+
H, W, C = rgbs[0].shape
|
| 863 |
+
assert(C==3)
|
| 864 |
+
|
| 865 |
+
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
|
| 866 |
+
|
| 867 |
+
S1, D = traj.shape
|
| 868 |
+
assert(D==2)
|
| 869 |
+
|
| 870 |
+
x = int(np.clip(traj[0,0], 0, W-1))
|
| 871 |
+
y = int(np.clip(traj[0,1], 0, H-1))
|
| 872 |
+
color = rgbs[0][y,x]
|
| 873 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
| 874 |
+
for s in range(S):
|
| 875 |
+
# bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb
|
| 876 |
+
# cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)
|
| 877 |
+
cv2.polylines(rgbs[s],
|
| 878 |
+
[traj[:s+1]],
|
| 879 |
+
False,
|
| 880 |
+
color,
|
| 881 |
+
linewidth,
|
| 882 |
+
cv2.LINE_AA)
|
| 883 |
+
return rgbs
|
| 884 |
+
|
| 885 |
+
def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
|
| 886 |
+
# all inputs are numpy tensors
|
| 887 |
+
# rgbs is a list of 3,H,W
|
| 888 |
+
# xy is N,2
|
| 889 |
+
H, W, C = rgb.shape
|
| 890 |
+
assert(C==3)
|
| 891 |
+
|
| 892 |
+
rgb = rgb.astype(np.uint8).copy()
|
| 893 |
+
|
| 894 |
+
N, D = xy.shape
|
| 895 |
+
assert(D==2)
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
xy = xy.astype(np.float32)
|
| 899 |
+
xy[:,0] = np.clip(xy[:,0], 0, W-1)
|
| 900 |
+
xy[:,1] = np.clip(xy[:,1], 0, H-1)
|
| 901 |
+
xy = xy.astype(np.int32)
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
if colors is None:
|
| 906 |
+
colors = get_n_colors(N)
|
| 907 |
+
|
| 908 |
+
for n in range(N):
|
| 909 |
+
color = colors[n]
|
| 910 |
+
# print('color', color)
|
| 911 |
+
# color = (color[0]*255).astype(np.uint8)
|
| 912 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
| 913 |
+
|
| 914 |
+
# x = int(np.clip(xy[0,0], 0, W-1))
|
| 915 |
+
# y = int(np.clip(xy[0,1], 0, H-1))
|
| 916 |
+
# color_ = rgbs[0][y,x]
|
| 917 |
+
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
|
| 918 |
+
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
|
| 919 |
+
|
| 920 |
+
cv2.circle(rgb, (int(xy[n,0]), int(xy[n,1])), linewidth, color, 3)
|
| 921 |
+
# vis_color = int(np.squeeze(vis[s])*255)
|
| 922 |
+
# vis_color = (vis_color,vis_color,vis_color)
|
| 923 |
+
# cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
|
| 924 |
+
return rgb
|
| 925 |
+
|
| 926 |
+
def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
|
| 927 |
+
# all inputs are numpy tensors
|
| 928 |
+
# rgbs is a list of 3,H,W
|
| 929 |
+
# traj is S,2
|
| 930 |
+
H, W, C = rgbs[0].shape
|
| 931 |
+
assert(C==3)
|
| 932 |
+
|
| 933 |
+
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
|
| 934 |
+
|
| 935 |
+
S1, D = traj.shape
|
| 936 |
+
assert(D==2)
|
| 937 |
+
|
| 938 |
+
if cmap is None:
|
| 939 |
+
bremm = ColorMap2d()
|
| 940 |
+
traj_ = traj[0:1].astype(np.float32)
|
| 941 |
+
traj_[:,0] /= float(W)
|
| 942 |
+
traj_[:,1] /= float(H)
|
| 943 |
+
color = bremm(traj_)
|
| 944 |
+
# print('color', color)
|
| 945 |
+
color = (color[0]*255).astype(np.uint8)
|
| 946 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
| 947 |
+
|
| 948 |
+
for s in range(S):
|
| 949 |
+
if cmap is not None:
|
| 950 |
+
color_map = cm.get_cmap(cmap)
|
| 951 |
+
# color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb
|
| 952 |
+
color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
|
| 953 |
+
# color = color.astype(np.uint8)
|
| 954 |
+
# color = (color[0], color[1], color[2])
|
| 955 |
+
# print('color', color)
|
| 956 |
+
# import ipdb; ipdb.set_trace()
|
| 957 |
+
|
| 958 |
+
cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+2, color, -1)
|
| 959 |
+
vis_color = int(np.squeeze(vis[s])*255)
|
| 960 |
+
vis_color = (vis_color,vis_color,vis_color)
|
| 961 |
+
cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)
|
| 962 |
+
|
| 963 |
+
return rgbs
|
| 964 |
+
|
| 965 |
+
def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False):
|
| 966 |
+
# trajs is B, S, N, 2
|
| 967 |
+
# rgbs is B, S, C, H, W
|
| 968 |
+
B, C, H, W = rgb.shape
|
| 969 |
+
B, S, N, D = trajs.shape
|
| 970 |
+
|
| 971 |
+
rgb = rgb[0] # C, H, W
|
| 972 |
+
trajs = trajs[0] # S, N, 2
|
| 973 |
+
if valids is None:
|
| 974 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
| 975 |
+
else:
|
| 976 |
+
valids = valids[0]
|
| 977 |
+
if visibs is None:
|
| 978 |
+
visibs = torch.ones_like(trajs[:,:,0]) # S, N
|
| 979 |
+
else:
|
| 980 |
+
visibs = visibs[0]
|
| 981 |
+
|
| 982 |
+
trajs = trajs.clamp(-16, W+16)
|
| 983 |
+
|
| 984 |
+
if N > max_show:
|
| 985 |
+
inds = np.random.choice(N, max_show)
|
| 986 |
+
trajs = trajs[:,inds]
|
| 987 |
+
valids = valids[:,inds]
|
| 988 |
+
visibs = visibs[:,inds]
|
| 989 |
+
N = trajs.shape[1]
|
| 990 |
+
|
| 991 |
+
if not already_sorted:
|
| 992 |
+
inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
|
| 993 |
+
trajs = trajs[:,inds]
|
| 994 |
+
valids = valids[:,inds]
|
| 995 |
+
visibs = visibs[:,inds]
|
| 996 |
+
|
| 997 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
| 998 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
| 999 |
+
|
| 1000 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
| 1001 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
| 1002 |
+
visibs = visibs.long().detach().cpu().numpy() # S, N
|
| 1003 |
+
|
| 1004 |
+
rgb = rgb.astype(np.uint8).copy()
|
| 1005 |
+
|
| 1006 |
+
for i in range(min(N, max_show)):
|
| 1007 |
+
if cmap=='onediff' and i==0:
|
| 1008 |
+
cmap_ = 'spring'
|
| 1009 |
+
elif cmap=='onediff':
|
| 1010 |
+
cmap_ = 'winter'
|
| 1011 |
+
else:
|
| 1012 |
+
cmap_ = cmap
|
| 1013 |
+
traj = trajs[:,i] # S,2
|
| 1014 |
+
valid = valids[:,i] # S
|
| 1015 |
+
visib = visibs[:,i] # S
|
| 1016 |
+
|
| 1017 |
+
if colors is None:
|
| 1018 |
+
ii = i/(1e-4+N-1.0)
|
| 1019 |
+
color_map = cm.get_cmap(cmap)
|
| 1020 |
+
color = np.array(color_map(ii)[:3]) * 255 # rgb
|
| 1021 |
+
else:
|
| 1022 |
+
color = np.array(colors[i]).astype(np.int64)
|
| 1023 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
| 1024 |
+
|
| 1025 |
+
for s in range(S):
|
| 1026 |
+
if valid[s]:
|
| 1027 |
+
if visib[s]:
|
| 1028 |
+
thickness = -1
|
| 1029 |
+
else:
|
| 1030 |
+
thickness = 2
|
| 1031 |
+
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, thickness)
|
| 1032 |
+
rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
|
| 1033 |
+
rgb = preprocess_color(rgb)
|
| 1034 |
+
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
|
| 1035 |
+
|
| 1036 |
+
def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None):
|
| 1037 |
+
# trajs is B, S, N, 2
|
| 1038 |
+
# rgbs is B, S, C, H, W
|
| 1039 |
+
B, S, C, H, W = rgbs.shape
|
| 1040 |
+
B, S2, N, D = trajs.shape
|
| 1041 |
+
assert(S==S2)
|
| 1042 |
+
|
| 1043 |
+
rgbs = rgbs[0] # S, C, H, W
|
| 1044 |
+
trajs = trajs[0] # S, N, 2
|
| 1045 |
+
if valids is None:
|
| 1046 |
+
valids = torch.ones_like(trajs[:,:,0]) # S, N
|
| 1047 |
+
else:
|
| 1048 |
+
valids = valids[0]
|
| 1049 |
+
if visibs is None:
|
| 1050 |
+
visibs = torch.ones_like(trajs[:,:,0]) # S, N
|
| 1051 |
+
else:
|
| 1052 |
+
visibs = visibs[0]
|
| 1053 |
+
|
| 1054 |
+
if N > max_show:
|
| 1055 |
+
inds = np.random.choice(N, max_show)
|
| 1056 |
+
trajs = trajs[:,inds]
|
| 1057 |
+
valids = valids[:,inds]
|
| 1058 |
+
visibs = visibs[:,inds]
|
| 1059 |
+
N = trajs.shape[1]
|
| 1060 |
+
inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
|
| 1061 |
+
trajs = trajs[:,inds]
|
| 1062 |
+
valids = valids[:,inds]
|
| 1063 |
+
visibs = visibs[:,inds]
|
| 1064 |
+
|
| 1065 |
+
rgbs_color = []
|
| 1066 |
+
for rgb in rgbs:
|
| 1067 |
+
rgb = back2color(rgb).detach().cpu().numpy()
|
| 1068 |
+
rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
|
| 1069 |
+
rgbs_color.append(rgb) # each element 3 x H x W
|
| 1070 |
+
|
| 1071 |
+
trajs = trajs.long().detach().cpu().numpy() # S, N, 2
|
| 1072 |
+
valids = valids.long().detach().cpu().numpy() # S, N
|
| 1073 |
+
visibs = visibs.long().detach().cpu().numpy() # S, N
|
| 1074 |
+
|
| 1075 |
+
rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]
|
| 1076 |
+
|
| 1077 |
+
for i in range(min(N, max_show)):
|
| 1078 |
+
traj = trajs[:,i] # S,2
|
| 1079 |
+
valid = valids[:,i] # S
|
| 1080 |
+
visib = visibs[:,i] # S
|
| 1081 |
+
|
| 1082 |
+
if colors is None:
|
| 1083 |
+
ii = i/(1e-4+N-1.0)
|
| 1084 |
+
color_map = cm.get_cmap(cmap)
|
| 1085 |
+
color = np.array(color_map(ii)[:3]) * 255 # rgb
|
| 1086 |
+
else:
|
| 1087 |
+
color = np.array(colors[i]).astype(np.int64)
|
| 1088 |
+
color = (int(color[0]),int(color[1]),int(color[2]))
|
| 1089 |
+
|
| 1090 |
+
for s in range(S):
|
| 1091 |
+
if valid[s]:
|
| 1092 |
+
if visib[s]:
|
| 1093 |
+
thickness = -1
|
| 1094 |
+
else:
|
| 1095 |
+
thickness = 2
|
| 1096 |
+
cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)
|
| 1097 |
+
rgbs = []
|
| 1098 |
+
for rgb in rgbs_color:
|
| 1099 |
+
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
|
| 1100 |
+
rgbs.append(preprocess_color(rgb))
|
| 1101 |
+
|
| 1102 |
+
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
|
| 1103 |
+
|
utils/loss.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from typing import List
|
| 5 |
+
import utils.basic
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def sequence_loss(
|
| 9 |
+
flow_preds,
|
| 10 |
+
flow_gt,
|
| 11 |
+
valids,
|
| 12 |
+
vis=None,
|
| 13 |
+
gamma=0.8,
|
| 14 |
+
use_huber_loss=False,
|
| 15 |
+
loss_only_for_visible=False,
|
| 16 |
+
):
|
| 17 |
+
"""Loss function defined over sequence of flow predictions"""
|
| 18 |
+
total_flow_loss = 0.0
|
| 19 |
+
for j in range(len(flow_gt)):
|
| 20 |
+
B, S, N, D = flow_gt[j].shape
|
| 21 |
+
B, S2, N = valids[j].shape
|
| 22 |
+
assert S == S2
|
| 23 |
+
n_predictions = len(flow_preds[j])
|
| 24 |
+
flow_loss = 0.0
|
| 25 |
+
for i in range(n_predictions):
|
| 26 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
| 27 |
+
flow_pred = flow_preds[j][i]
|
| 28 |
+
if use_huber_loss:
|
| 29 |
+
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
|
| 30 |
+
else:
|
| 31 |
+
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
| 32 |
+
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
| 33 |
+
valid_ = valids[j].clone()
|
| 34 |
+
if loss_only_for_visible:
|
| 35 |
+
valid_ = valid_ * vis[j]
|
| 36 |
+
flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
|
| 37 |
+
flow_loss = flow_loss / n_predictions
|
| 38 |
+
total_flow_loss += flow_loss
|
| 39 |
+
return total_flow_loss / len(flow_gt)
|
| 40 |
+
|
| 41 |
+
def sequence_loss_dense(
|
| 42 |
+
flow_preds,
|
| 43 |
+
flow_gt,
|
| 44 |
+
valids,
|
| 45 |
+
vis=None,
|
| 46 |
+
gamma=0.8,
|
| 47 |
+
use_huber_loss=False,
|
| 48 |
+
loss_only_for_visible=False,
|
| 49 |
+
):
|
| 50 |
+
"""Loss function defined over sequence of flow predictions"""
|
| 51 |
+
total_flow_loss = 0.0
|
| 52 |
+
for j in range(len(flow_gt)):
|
| 53 |
+
# print('flow_gt[j]', flow_gt[j].shape)
|
| 54 |
+
B, S, D, H, W = flow_gt[j].shape
|
| 55 |
+
B, S2, _, H, W = valids[j].shape
|
| 56 |
+
assert S == S2
|
| 57 |
+
n_predictions = len(flow_preds[j])
|
| 58 |
+
flow_loss = 0.0
|
| 59 |
+
# import ipdb; ipdb.set_trace()
|
| 60 |
+
for i in range(n_predictions):
|
| 61 |
+
# print('flow_e[j][i]', flow_preds[j][i].shape)
|
| 62 |
+
i_weight = gamma ** (n_predictions - i - 1)
|
| 63 |
+
flow_pred = flow_preds[j][i] # B,S,2,H,W
|
| 64 |
+
if use_huber_loss:
|
| 65 |
+
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0) # B,S,2,H,W
|
| 66 |
+
else:
|
| 67 |
+
i_loss = (flow_pred - flow_gt[j]).abs() # B,S,2,H,W
|
| 68 |
+
i_loss_ = torch.mean(i_loss, dim=2) # B,S,H,W
|
| 69 |
+
valid_ = valids[j].reshape(B,S,H,W)
|
| 70 |
+
# print(' (%d,%d) i_loss_' % (i,j), i_loss_.shape)
|
| 71 |
+
# print(' (%d,%d) valid_' % (i,j), valid_.shape)
|
| 72 |
+
if loss_only_for_visible:
|
| 73 |
+
valid_ = valid_ * vis[j].reshape(B,-1,H,W) # usually B,S,H,W, but maybe B,1,H,W
|
| 74 |
+
flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
|
| 75 |
+
# import ipdb; ipdb.set_trace()
|
| 76 |
+
flow_loss = flow_loss / n_predictions
|
| 77 |
+
total_flow_loss += flow_loss
|
| 78 |
+
return total_flow_loss / len(flow_gt)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def huber_loss(x, y, delta=1.0):
|
| 82 |
+
"""Calculate element-wise Huber loss between x and y"""
|
| 83 |
+
diff = x - y
|
| 84 |
+
abs_diff = diff.abs()
|
| 85 |
+
flag = (abs_diff <= delta).float()
|
| 86 |
+
return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
|
| 90 |
+
total_bce_loss = 0.0
|
| 91 |
+
# all_vis_preds = [torch.stack(vp) for vp in vis_preds]
|
| 92 |
+
# all_vis_preds = torch.stack(all_vis_preds)
|
| 93 |
+
# utils.basic.print_stats('all_vis_preds', all_vis_preds)
|
| 94 |
+
for j in range(len(vis_preds)):
|
| 95 |
+
n_predictions = len(vis_preds[j])
|
| 96 |
+
bce_loss = 0.0
|
| 97 |
+
for i in range(n_predictions):
|
| 98 |
+
# utils.basic.print_stats('vis_preds[%d][%d]' % (j,i), vis_preds[j][i])
|
| 99 |
+
# utils.basic.print_stats('vis_gts[%d]' % (i), vis_gts[i])
|
| 100 |
+
if use_logits:
|
| 101 |
+
loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
|
| 102 |
+
else:
|
| 103 |
+
loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
|
| 104 |
+
if valids is None:
|
| 105 |
+
bce_loss += loss.mean()
|
| 106 |
+
else:
|
| 107 |
+
bce_loss += (loss * valids[j]).mean()
|
| 108 |
+
bce_loss = bce_loss / n_predictions
|
| 109 |
+
total_bce_loss += bce_loss
|
| 110 |
+
return total_bce_loss / len(vis_preds)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# def sequence_BCE_loss_dense(vis_preds, vis_gts):
|
| 114 |
+
# total_bce_loss = 0.0
|
| 115 |
+
# for j in range(len(vis_preds)):
|
| 116 |
+
# n_predictions = len(vis_preds[j])
|
| 117 |
+
# bce_loss = 0.0
|
| 118 |
+
# for i in range(n_predictions):
|
| 119 |
+
# vis_e = vis_preds[j][i]
|
| 120 |
+
# vis_g = vis_gts[j]
|
| 121 |
+
# print('vis_e', vis_e.shape, 'vis_g', vis_g.shape)
|
| 122 |
+
# vis_loss = F.binary_cross_entropy(vis_e, vis_g)
|
| 123 |
+
# bce_loss += vis_loss
|
| 124 |
+
# bce_loss = bce_loss / n_predictions
|
| 125 |
+
# total_bce_loss += bce_loss
|
| 126 |
+
# return total_bce_loss / len(vis_preds)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def sequence_prob_loss(
|
| 130 |
+
tracks: torch.Tensor,
|
| 131 |
+
confidence: torch.Tensor,
|
| 132 |
+
target_points: torch.Tensor,
|
| 133 |
+
visibility: torch.Tensor,
|
| 134 |
+
expected_dist_thresh: float = 12.0,
|
| 135 |
+
use_logits=False,
|
| 136 |
+
):
|
| 137 |
+
"""Loss for classifying if a point is within pixel threshold of its target."""
|
| 138 |
+
# Points with an error larger than 12 pixels are likely to be useless; marking
|
| 139 |
+
# them as occluded will actually improve Jaccard metrics and give
|
| 140 |
+
# qualitatively better results.
|
| 141 |
+
total_logprob_loss = 0.0
|
| 142 |
+
for j in range(len(tracks)):
|
| 143 |
+
n_predictions = len(tracks[j])
|
| 144 |
+
logprob_loss = 0.0
|
| 145 |
+
for i in range(n_predictions):
|
| 146 |
+
err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
|
| 147 |
+
valid = (err <= expected_dist_thresh**2).float()
|
| 148 |
+
if use_logits:
|
| 149 |
+
loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
|
| 150 |
+
else:
|
| 151 |
+
loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
|
| 152 |
+
loss *= visibility[j]
|
| 153 |
+
loss = torch.mean(loss, dim=[1, 2])
|
| 154 |
+
logprob_loss += loss
|
| 155 |
+
logprob_loss = logprob_loss / n_predictions
|
| 156 |
+
total_logprob_loss += logprob_loss
|
| 157 |
+
return total_logprob_loss / len(tracks)
|
| 158 |
+
|
| 159 |
+
def sequence_prob_loss_dense(
|
| 160 |
+
tracks: torch.Tensor,
|
| 161 |
+
confidence: torch.Tensor,
|
| 162 |
+
target_points: torch.Tensor,
|
| 163 |
+
visibility: torch.Tensor,
|
| 164 |
+
expected_dist_thresh: float = 12.0,
|
| 165 |
+
use_logits=False,
|
| 166 |
+
):
|
| 167 |
+
"""Loss for classifying if a point is within pixel threshold of its target."""
|
| 168 |
+
# Points with an error larger than 12 pixels are likely to be useless; marking
|
| 169 |
+
# them as occluded will actually improve Jaccard metrics and give
|
| 170 |
+
# qualitatively better results.
|
| 171 |
+
|
| 172 |
+
# all_confidence = [torch.stack(vp) for vp in confidence]
|
| 173 |
+
# all_confidence = torch.stack(all_confidence)
|
| 174 |
+
# utils.basic.print_stats('all_confidence', all_confidence)
|
| 175 |
+
|
| 176 |
+
total_logprob_loss = 0.0
|
| 177 |
+
for j in range(len(tracks)):
|
| 178 |
+
n_predictions = len(tracks[j])
|
| 179 |
+
logprob_loss = 0.0
|
| 180 |
+
for i in range(n_predictions):
|
| 181 |
+
# print('trajs_e', tracks[j][i].shape)
|
| 182 |
+
# print('trajs_g', target_points[j].shape)
|
| 183 |
+
err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
|
| 184 |
+
positive = (err <= expected_dist_thresh**2).float()
|
| 185 |
+
# print('conf', confidence[j][i].shape, 'positive', positive.shape)
|
| 186 |
+
if use_logits:
|
| 187 |
+
loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
|
| 188 |
+
else:
|
| 189 |
+
loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
|
| 190 |
+
loss *= visibility[j].squeeze(2) # B,S,H,W
|
| 191 |
+
loss = torch.mean(loss, dim=[1,2,3])
|
| 192 |
+
logprob_loss += loss
|
| 193 |
+
logprob_loss = logprob_loss / n_predictions
|
| 194 |
+
total_logprob_loss += logprob_loss
|
| 195 |
+
return total_logprob_loss / len(tracks)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def masked_mean(data, mask, dim):
|
| 199 |
+
if mask is None:
|
| 200 |
+
return data.mean(dim=dim, keepdim=True)
|
| 201 |
+
mask = mask.float()
|
| 202 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
| 203 |
+
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
|
| 204 |
+
mask_sum, min=1.0
|
| 205 |
+
)
|
| 206 |
+
return mask_mean
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
|
| 210 |
+
if mask is None:
|
| 211 |
+
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
|
| 212 |
+
mask = mask.float()
|
| 213 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
| 214 |
+
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
|
| 215 |
+
mask_sum, min=1.0
|
| 216 |
+
)
|
| 217 |
+
mask_var = torch.sum(
|
| 218 |
+
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
|
| 219 |
+
) / torch.clamp(mask_sum, min=1.0)
|
| 220 |
+
return mask_mean.squeeze(dim), mask_var.squeeze(dim)
|
utils/misc.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, positions):
|
| 5 |
+
assert embed_dim % 2 == 0
|
| 6 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
| 7 |
+
omega /= embed_dim / 2.0
|
| 8 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 9 |
+
|
| 10 |
+
positions = positions.reshape(-1) # (M,)
|
| 11 |
+
out = torch.einsum("m,d->md", positions, omega) # (M, D/2), outer product
|
| 12 |
+
|
| 13 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 14 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 15 |
+
|
| 16 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 17 |
+
return emb[None].float()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SimplePool():
|
| 21 |
+
def __init__(self, pool_size, version='pt', min_size=1):
|
| 22 |
+
self.pool_size = pool_size
|
| 23 |
+
self.version = version
|
| 24 |
+
self.items = []
|
| 25 |
+
self.min_size = min_size
|
| 26 |
+
|
| 27 |
+
if not (version=='pt' or version=='np'):
|
| 28 |
+
print('version = %s; please choose pt or np')
|
| 29 |
+
assert(False) # please choose pt or np
|
| 30 |
+
|
| 31 |
+
def __len__(self):
|
| 32 |
+
return len(self.items)
|
| 33 |
+
|
| 34 |
+
def mean(self, min_size=None):
|
| 35 |
+
if min_size is None:
|
| 36 |
+
pool_size_thresh = self.min_size
|
| 37 |
+
elif min_size=='half':
|
| 38 |
+
pool_size_thresh = self.pool_size/2
|
| 39 |
+
else:
|
| 40 |
+
pool_size_thresh = min_size
|
| 41 |
+
|
| 42 |
+
if self.version=='np':
|
| 43 |
+
if len(self.items) >= pool_size_thresh:
|
| 44 |
+
return np.sum(self.items)/float(len(self.items))
|
| 45 |
+
else:
|
| 46 |
+
return np.nan
|
| 47 |
+
if self.version=='pt':
|
| 48 |
+
if len(self.items) >= pool_size_thresh:
|
| 49 |
+
return torch.sum(self.items)/float(len(self.items))
|
| 50 |
+
else:
|
| 51 |
+
return torch.from_numpy(np.nan)
|
| 52 |
+
|
| 53 |
+
def sample(self, with_replacement=True):
|
| 54 |
+
idx = np.random.randint(len(self.items))
|
| 55 |
+
if with_replacement:
|
| 56 |
+
return self.items[idx]
|
| 57 |
+
else:
|
| 58 |
+
return self.items.pop(idx)
|
| 59 |
+
|
| 60 |
+
def fetch(self, num=None):
|
| 61 |
+
if self.version=='pt':
|
| 62 |
+
item_array = torch.stack(self.items)
|
| 63 |
+
elif self.version=='np':
|
| 64 |
+
item_array = np.stack(self.items)
|
| 65 |
+
if num is not None:
|
| 66 |
+
# there better be some items
|
| 67 |
+
assert(len(self.items) >= num)
|
| 68 |
+
|
| 69 |
+
# if there are not that many elements just return however many there are
|
| 70 |
+
if len(self.items) < num:
|
| 71 |
+
return item_array
|
| 72 |
+
else:
|
| 73 |
+
idxs = np.random.randint(len(self.items), size=num)
|
| 74 |
+
return item_array[idxs]
|
| 75 |
+
else:
|
| 76 |
+
return item_array
|
| 77 |
+
|
| 78 |
+
def is_full(self):
|
| 79 |
+
full = len(self.items)==self.pool_size
|
| 80 |
+
return full
|
| 81 |
+
|
| 82 |
+
def empty(self):
|
| 83 |
+
self.items = []
|
| 84 |
+
|
| 85 |
+
def have_min_size(self):
|
| 86 |
+
return len(self.items) >= self.min_size
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def update(self, items):
|
| 90 |
+
for item in items:
|
| 91 |
+
if len(self.items) < self.pool_size:
|
| 92 |
+
# the pool is not full, so let's add this in
|
| 93 |
+
self.items.append(item)
|
| 94 |
+
else:
|
| 95 |
+
# the pool is full
|
| 96 |
+
# pop from the front
|
| 97 |
+
self.items.pop(0)
|
| 98 |
+
# add to the back
|
| 99 |
+
self.items.append(item)
|
| 100 |
+
return self.items
|
utils/py.py
ADDED
|
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob, math
|
| 2 |
+
import numpy as np
|
| 3 |
+
# from scipy import misc
|
| 4 |
+
# from scipy import linalg
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import io
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
EPS = 1e-6
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
XMIN = -64.0 # right (neg is left)
|
| 12 |
+
XMAX = 64.0 # right
|
| 13 |
+
YMIN = -64.0 # down (neg is up)
|
| 14 |
+
YMAX = 64.0 # down
|
| 15 |
+
ZMIN = -64.0 # forward
|
| 16 |
+
ZMAX = 64.0 # forward
|
| 17 |
+
|
| 18 |
+
def print_stats(name, tensor):
|
| 19 |
+
tensor = tensor.astype(np.float32)
|
| 20 |
+
print('%s min = %.2f, mean = %.2f, max = %.2f' % (name, np.min(tensor), np.mean(tensor), np.max(tensor)), tensor.shape)
|
| 21 |
+
|
| 22 |
+
def reduce_masked_mean(x, mask, axis=None, keepdims=False):
|
| 23 |
+
# x and mask are the same shape
|
| 24 |
+
# returns shape-1
|
| 25 |
+
# axis can be a list of axes
|
| 26 |
+
prod = x*mask
|
| 27 |
+
numer = np.sum(prod, axis=axis, keepdims=keepdims)
|
| 28 |
+
denom = EPS+np.sum(mask, axis=axis, keepdims=keepdims)
|
| 29 |
+
mean = numer/denom
|
| 30 |
+
return mean
|
| 31 |
+
|
| 32 |
+
def reduce_masked_sum(x, mask, axis=None, keepdims=False):
|
| 33 |
+
# x and mask are the same shape
|
| 34 |
+
# returns shape-1
|
| 35 |
+
# axis can be a list of axes
|
| 36 |
+
prod = x*mask
|
| 37 |
+
numer = np.sum(prod, axis=axis, keepdims=keepdims)
|
| 38 |
+
return numer
|
| 39 |
+
|
| 40 |
+
def reduce_masked_median(x, mask, keep_batch=False):
|
| 41 |
+
# x and mask are the same shape
|
| 42 |
+
# returns shape-1
|
| 43 |
+
# axis can be a list of axes
|
| 44 |
+
|
| 45 |
+
if not (x.shape == mask.shape):
|
| 46 |
+
print('reduce_masked_median: these shapes should match:', x.shape, mask.shape)
|
| 47 |
+
assert(False)
|
| 48 |
+
# assert(x.shape == mask.shape)
|
| 49 |
+
|
| 50 |
+
B = list(x.shape)[0]
|
| 51 |
+
|
| 52 |
+
if keep_batch:
|
| 53 |
+
x = np.reshape(x, [B, -1])
|
| 54 |
+
mask = np.reshape(mask, [B, -1])
|
| 55 |
+
meds = np.zeros([B], np.float32)
|
| 56 |
+
for b in list(range(B)):
|
| 57 |
+
xb = x[b]
|
| 58 |
+
mb = mask[b]
|
| 59 |
+
if np.sum(mb) > 0:
|
| 60 |
+
xb = xb[mb > 0]
|
| 61 |
+
meds[b] = np.median(xb)
|
| 62 |
+
else:
|
| 63 |
+
meds[b] = np.nan
|
| 64 |
+
return meds
|
| 65 |
+
else:
|
| 66 |
+
x = np.reshape(x, [-1])
|
| 67 |
+
mask = np.reshape(mask, [-1])
|
| 68 |
+
if np.sum(mask) > 0:
|
| 69 |
+
x = x[mask > 0]
|
| 70 |
+
med = np.median(x)
|
| 71 |
+
else:
|
| 72 |
+
med = np.nan
|
| 73 |
+
med = np.array([med], np.float32)
|
| 74 |
+
return med
|
| 75 |
+
|
| 76 |
+
def get_nFiles(path):
|
| 77 |
+
return len(glob.glob(path))
|
| 78 |
+
|
| 79 |
+
def get_file_list(path):
|
| 80 |
+
return glob.glob(path)
|
| 81 |
+
|
| 82 |
+
def rotm2eul(R):
|
| 83 |
+
# R is 3x3
|
| 84 |
+
sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
|
| 85 |
+
if sy > 1e-6: # singular
|
| 86 |
+
x = math.atan2(R[2,1] , R[2,2])
|
| 87 |
+
y = math.atan2(-R[2,0], sy)
|
| 88 |
+
z = math.atan2(R[1,0], R[0,0])
|
| 89 |
+
else:
|
| 90 |
+
x = math.atan2(-R[1,2], R[1,1])
|
| 91 |
+
y = math.atan2(-R[2,0], sy)
|
| 92 |
+
z = 0
|
| 93 |
+
return x, y, z
|
| 94 |
+
|
| 95 |
+
def rad2deg(rad):
|
| 96 |
+
return rad*180.0/np.pi
|
| 97 |
+
|
| 98 |
+
def deg2rad(deg):
|
| 99 |
+
return deg/180.0*np.pi
|
| 100 |
+
|
| 101 |
+
def eul2rotm(rx, ry, rz):
|
| 102 |
+
# copy of matlab, but order of inputs is different
|
| 103 |
+
# R = [ cy*cz sy*sx*cz-sz*cx sy*cx*cz+sz*sx
|
| 104 |
+
# cy*sz sy*sx*sz+cz*cx sy*cx*sz-cz*sx
|
| 105 |
+
# -sy cy*sx cy*cx]
|
| 106 |
+
sinz = np.sin(rz)
|
| 107 |
+
siny = np.sin(ry)
|
| 108 |
+
sinx = np.sin(rx)
|
| 109 |
+
cosz = np.cos(rz)
|
| 110 |
+
cosy = np.cos(ry)
|
| 111 |
+
cosx = np.cos(rx)
|
| 112 |
+
r11 = cosy*cosz
|
| 113 |
+
r12 = sinx*siny*cosz - cosx*sinz
|
| 114 |
+
r13 = cosx*siny*cosz + sinx*sinz
|
| 115 |
+
r21 = cosy*sinz
|
| 116 |
+
r22 = sinx*siny*sinz + cosx*cosz
|
| 117 |
+
r23 = cosx*siny*sinz - sinx*cosz
|
| 118 |
+
r31 = -siny
|
| 119 |
+
r32 = sinx*cosy
|
| 120 |
+
r33 = cosx*cosy
|
| 121 |
+
r1 = np.stack([r11,r12,r13],axis=-1)
|
| 122 |
+
r2 = np.stack([r21,r22,r23],axis=-1)
|
| 123 |
+
r3 = np.stack([r31,r32,r33],axis=-1)
|
| 124 |
+
r = np.stack([r1,r2,r3],axis=0)
|
| 125 |
+
return r
|
| 126 |
+
|
| 127 |
+
def wrap2pi(rad_angle):
|
| 128 |
+
# puts the angle into the range [-pi, pi]
|
| 129 |
+
return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
|
| 130 |
+
|
| 131 |
+
def rot2view(rx,ry,rz,x,y,z):
|
| 132 |
+
# takes rot angles and 3d position as input
|
| 133 |
+
# returns viewpoint angles as output
|
| 134 |
+
# (all in radians)
|
| 135 |
+
# it will perform strangely if z <= 0
|
| 136 |
+
az = wrap2pi(ry - (-np.arctan2(z, x) - 1.5*np.pi))
|
| 137 |
+
el = -wrap2pi(rx - (-np.arctan2(z, y) - 1.5*np.pi))
|
| 138 |
+
th = -rz
|
| 139 |
+
return az, el, th
|
| 140 |
+
|
| 141 |
+
def invAxB(a,b):
|
| 142 |
+
"""
|
| 143 |
+
Compute the relative 3d transformation between a and b.
|
| 144 |
+
|
| 145 |
+
Input:
|
| 146 |
+
a -- first pose (homogeneous 4x4 matrix)
|
| 147 |
+
b -- second pose (homogeneous 4x4 matrix)
|
| 148 |
+
|
| 149 |
+
Output:
|
| 150 |
+
Relative 3d transformation from a to b.
|
| 151 |
+
"""
|
| 152 |
+
return np.dot(np.linalg.inv(a),b)
|
| 153 |
+
|
| 154 |
+
def merge_rt(r, t):
|
| 155 |
+
# r is 3 x 3
|
| 156 |
+
# t is 3 or maybe 3 x 1
|
| 157 |
+
t = np.reshape(t, [3, 1])
|
| 158 |
+
rt = np.concatenate((r,t), axis=1)
|
| 159 |
+
# rt is 3 x 4
|
| 160 |
+
br = np.reshape(np.array([0,0,0,1], np.float32), [1, 4])
|
| 161 |
+
# br is 1 x 4
|
| 162 |
+
rt = np.concatenate((rt, br), axis=0)
|
| 163 |
+
# rt is 4 x 4
|
| 164 |
+
return rt
|
| 165 |
+
|
| 166 |
+
def split_rt(rt):
|
| 167 |
+
r = rt[:3,:3]
|
| 168 |
+
t = rt[:3,3]
|
| 169 |
+
r = np.reshape(r, [3, 3])
|
| 170 |
+
t = np.reshape(t, [3, 1])
|
| 171 |
+
return r, t
|
| 172 |
+
|
| 173 |
+
def split_intrinsics(K):
|
| 174 |
+
# K is 3 x 4 or 4 x 4
|
| 175 |
+
fx = K[0,0]
|
| 176 |
+
fy = K[1,1]
|
| 177 |
+
x0 = K[0,2]
|
| 178 |
+
y0 = K[1,2]
|
| 179 |
+
return fx, fy, x0, y0
|
| 180 |
+
|
| 181 |
+
def merge_intrinsics(fx, fy, x0, y0):
|
| 182 |
+
# inputs are shaped []
|
| 183 |
+
K = np.eye(4)
|
| 184 |
+
K[0,0] = fx
|
| 185 |
+
K[1,1] = fy
|
| 186 |
+
K[0,2] = x0
|
| 187 |
+
K[1,2] = y0
|
| 188 |
+
# K is shaped 4 x 4
|
| 189 |
+
return K
|
| 190 |
+
|
| 191 |
+
def scale_intrinsics(K, sx, sy):
|
| 192 |
+
fx, fy, x0, y0 = split_intrinsics(K)
|
| 193 |
+
fx *= sx
|
| 194 |
+
fy *= sy
|
| 195 |
+
x0 *= sx
|
| 196 |
+
y0 *= sy
|
| 197 |
+
return merge_intrinsics(fx, fy, x0, y0)
|
| 198 |
+
|
| 199 |
+
# def meshgrid(H, W):
|
| 200 |
+
# x = np.linspace(0, W-1, W)
|
| 201 |
+
# y = np.linspace(0, H-1, H)
|
| 202 |
+
# xv, yv = np.meshgrid(x, y)
|
| 203 |
+
# return xv, yv
|
| 204 |
+
|
| 205 |
+
def compute_distance(transform):
|
| 206 |
+
"""
|
| 207 |
+
Compute the distance of the translational component of a 4x4 homogeneous matrix.
|
| 208 |
+
"""
|
| 209 |
+
return numpy.linalg.norm(transform[0:3,3])
|
| 210 |
+
|
| 211 |
+
def radian_l1_dist(e, g):
|
| 212 |
+
# if our angles are in [0, 360] we can follow this stack overflow answer:
|
| 213 |
+
# https://gamedev.stackexchange.com/questions/4467/comparing-angles-and-working-out-the-difference
|
| 214 |
+
# wrap2pi brings the angles to [-180, 180]; adding pi puts them in [0, 360]
|
| 215 |
+
e = wrap2pi(e)+np.pi
|
| 216 |
+
g = wrap2pi(g)+np.pi
|
| 217 |
+
l = np.abs(np.pi - np.abs(np.abs(e-g) - np.pi))
|
| 218 |
+
return l
|
| 219 |
+
|
| 220 |
+
def apply_pix_T_cam(pix_T_cam, xyz):
|
| 221 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
| 222 |
+
# xyz is shaped B x H*W x 3
|
| 223 |
+
# returns xy, shaped B x H*W x 2
|
| 224 |
+
N, C = xyz.shape
|
| 225 |
+
x, y, z = np.split(xyz, 3, axis=-1)
|
| 226 |
+
EPS = 1e-4
|
| 227 |
+
z = np.clip(z, EPS, None)
|
| 228 |
+
x = (x*fx)/(z)+x0
|
| 229 |
+
y = (y*fy)/(z)+y0
|
| 230 |
+
xy = np.concatenate([x, y], axis=-1)
|
| 231 |
+
return xy
|
| 232 |
+
|
| 233 |
+
def apply_4x4(RT, XYZ):
|
| 234 |
+
# RT is 4 x 4
|
| 235 |
+
# XYZ is N x 3
|
| 236 |
+
|
| 237 |
+
# put into homogeneous coords
|
| 238 |
+
X, Y, Z = np.split(XYZ, 3, axis=1)
|
| 239 |
+
ones = np.ones_like(X)
|
| 240 |
+
XYZ1 = np.concatenate([X, Y, Z, ones], axis=1)
|
| 241 |
+
# XYZ1 is N x 4
|
| 242 |
+
|
| 243 |
+
XYZ1_t = np.transpose(XYZ1)
|
| 244 |
+
# this is 4 x N
|
| 245 |
+
|
| 246 |
+
XYZ2_t = np.dot(RT, XYZ1_t)
|
| 247 |
+
# this is 4 x N
|
| 248 |
+
|
| 249 |
+
XYZ2 = np.transpose(XYZ2_t)
|
| 250 |
+
# this is N x 4
|
| 251 |
+
|
| 252 |
+
XYZ2 = XYZ2[:,:3]
|
| 253 |
+
# this is N x 3
|
| 254 |
+
|
| 255 |
+
return XYZ2
|
| 256 |
+
|
| 257 |
+
def Ref2Mem(xyz, Z, Y, X):
|
| 258 |
+
# xyz is N x 3, in ref coordinates
|
| 259 |
+
# transforms ref coordinates into mem coordinates
|
| 260 |
+
N, C = xyz.shape
|
| 261 |
+
assert(C==3)
|
| 262 |
+
mem_T_ref = get_mem_T_ref(Z, Y, X)
|
| 263 |
+
xyz = apply_4x4(mem_T_ref, xyz)
|
| 264 |
+
return xyz
|
| 265 |
+
|
| 266 |
+
# def Mem2Ref(xyz_mem, MH, MW, MD):
|
| 267 |
+
# # xyz is B x N x 3, in mem coordinates
|
| 268 |
+
# # transforms mem coordinates into ref coordinates
|
| 269 |
+
# B, N, C = xyz_mem.get_shape().as_list()
|
| 270 |
+
# ref_T_mem = get_ref_T_mem(B, MH, MW, MD)
|
| 271 |
+
# xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
|
| 272 |
+
# return xyz_ref
|
| 273 |
+
|
| 274 |
+
def get_mem_T_ref(Z, Y, X):
|
| 275 |
+
# sometimes we want the mat itself
|
| 276 |
+
# note this is not a rigid transform
|
| 277 |
+
|
| 278 |
+
# for interpretability, let's construct this in two steps...
|
| 279 |
+
|
| 280 |
+
# translation
|
| 281 |
+
center_T_ref = np.eye(4, dtype=np.float32)
|
| 282 |
+
center_T_ref[0,3] = -XMIN
|
| 283 |
+
center_T_ref[1,3] = -YMIN
|
| 284 |
+
center_T_ref[2,3] = -ZMIN
|
| 285 |
+
|
| 286 |
+
VOX_SIZE_X = (XMAX-XMIN)/float(X)
|
| 287 |
+
VOX_SIZE_Y = (YMAX-YMIN)/float(Y)
|
| 288 |
+
VOX_SIZE_Z = (ZMAX-ZMIN)/float(Z)
|
| 289 |
+
|
| 290 |
+
# scaling
|
| 291 |
+
mem_T_center = np.eye(4, dtype=np.float32)
|
| 292 |
+
mem_T_center[0,0] = 1./VOX_SIZE_X
|
| 293 |
+
mem_T_center[1,1] = 1./VOX_SIZE_Y
|
| 294 |
+
mem_T_center[2,2] = 1./VOX_SIZE_Z
|
| 295 |
+
|
| 296 |
+
mem_T_ref = np.dot(mem_T_center, center_T_ref)
|
| 297 |
+
return mem_T_ref
|
| 298 |
+
|
| 299 |
+
def safe_inverse(a):
|
| 300 |
+
r, t = split_rt(a)
|
| 301 |
+
t = np.reshape(t, [3, 1])
|
| 302 |
+
r_transpose = r.T
|
| 303 |
+
inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
|
| 304 |
+
bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
|
| 305 |
+
inv = np.concatenate([inv, bottom_row], 0)
|
| 306 |
+
return inv
|
| 307 |
+
|
| 308 |
+
def get_ref_T_mem(Z, Y, X):
|
| 309 |
+
mem_T_ref = get_mem_T_ref(X, Y, X)
|
| 310 |
+
# note safe_inverse is inapplicable here,
|
| 311 |
+
# since the transform is nonrigid
|
| 312 |
+
ref_T_mem = np.linalg.inv(mem_T_ref)
|
| 313 |
+
return ref_T_mem
|
| 314 |
+
|
| 315 |
+
def voxelize_xyz(xyz_ref, Z, Y, X):
|
| 316 |
+
# xyz_ref is N x 3
|
| 317 |
+
xyz_mem = Ref2Mem(xyz_ref, Z, Y, X)
|
| 318 |
+
# this is N x 3
|
| 319 |
+
voxels = get_occupancy(xyz_mem, Z, Y, X)
|
| 320 |
+
voxels = np.reshape(voxels, [Z, Y, X, 1])
|
| 321 |
+
return voxels
|
| 322 |
+
|
| 323 |
+
def get_inbounds(xyz, Z, Y, X, already_mem=False):
|
| 324 |
+
# xyz is H*W x 3
|
| 325 |
+
|
| 326 |
+
if not already_mem:
|
| 327 |
+
xyz = Ref2Mem(xyz, Z, Y, X)
|
| 328 |
+
|
| 329 |
+
x_valid = np.logical_and(
|
| 330 |
+
np.greater_equal(xyz[:,0], -0.5),
|
| 331 |
+
np.less(xyz[:,0], float(X)-0.5))
|
| 332 |
+
y_valid = np.logical_and(
|
| 333 |
+
np.greater_equal(xyz[:,1], -0.5),
|
| 334 |
+
np.less(xyz[:,1], float(Y)-0.5))
|
| 335 |
+
z_valid = np.logical_and(
|
| 336 |
+
np.greater_equal(xyz[:,2], -0.5),
|
| 337 |
+
np.less(xyz[:,2], float(Z)-0.5))
|
| 338 |
+
inbounds = np.logical_and(np.logical_and(x_valid, y_valid), z_valid)
|
| 339 |
+
return inbounds
|
| 340 |
+
|
| 341 |
+
def sub2ind3d_zyx(depth, height, width, d, h, w):
|
| 342 |
+
# same as sub2ind3d, but inputs in zyx order
|
| 343 |
+
# when gathering/scattering with these inds, the tensor should be Z x Y x X
|
| 344 |
+
return d*height*width + h*width + w
|
| 345 |
+
|
| 346 |
+
def sub2ind3d_yxz(height, width, depth, h, w, d):
|
| 347 |
+
return h*width*depth + w*depth + d
|
| 348 |
+
|
| 349 |
+
# def ind2sub(height, width, ind):
|
| 350 |
+
# # int input
|
| 351 |
+
# y = int(ind / height)
|
| 352 |
+
# x = ind % height
|
| 353 |
+
# return y, x
|
| 354 |
+
|
| 355 |
+
def get_occupancy(xyz_mem, Z, Y, X):
|
| 356 |
+
# xyz_mem is N x 3
|
| 357 |
+
# we want to fill a voxel tensor with 1's at these inds
|
| 358 |
+
|
| 359 |
+
inbounds = get_inbounds(xyz_mem, Z, Y, X, already_mem=True)
|
| 360 |
+
inds = np.where(inbounds)
|
| 361 |
+
|
| 362 |
+
xyz_mem = np.reshape(xyz_mem[inds], [-1, 3])
|
| 363 |
+
# xyz_mem is N x 3
|
| 364 |
+
|
| 365 |
+
# this is more accurate than a cast/floor, but runs into issues when Y==0
|
| 366 |
+
xyz_mem = np.round(xyz_mem).astype(np.int32)
|
| 367 |
+
x = xyz_mem[:,0]
|
| 368 |
+
y = xyz_mem[:,1]
|
| 369 |
+
z = xyz_mem[:,2]
|
| 370 |
+
|
| 371 |
+
voxels = np.zeros([Z, Y, X], np.float32)
|
| 372 |
+
voxels[z, y, x] = 1.0
|
| 373 |
+
|
| 374 |
+
return voxels
|
| 375 |
+
|
| 376 |
+
def pixels2camera(x,y,z,fx,fy,x0,y0):
|
| 377 |
+
# x and y are locations in pixel coordinates, z is a depth image in meters
|
| 378 |
+
# their shapes are H x W
|
| 379 |
+
# fx, fy, x0, y0 are scalar camera intrinsics
|
| 380 |
+
# returns xyz, sized [B,H*W,3]
|
| 381 |
+
|
| 382 |
+
H, W = z.shape
|
| 383 |
+
|
| 384 |
+
fx = np.reshape(fx, [1,1])
|
| 385 |
+
fy = np.reshape(fy, [1,1])
|
| 386 |
+
x0 = np.reshape(x0, [1,1])
|
| 387 |
+
y0 = np.reshape(y0, [1,1])
|
| 388 |
+
|
| 389 |
+
# unproject
|
| 390 |
+
x = ((z+EPS)/fx)*(x-x0)
|
| 391 |
+
y = ((z+EPS)/fy)*(y-y0)
|
| 392 |
+
|
| 393 |
+
x = np.reshape(x, [-1])
|
| 394 |
+
y = np.reshape(y, [-1])
|
| 395 |
+
z = np.reshape(z, [-1])
|
| 396 |
+
xyz = np.stack([x,y,z], axis=1)
|
| 397 |
+
return xyz
|
| 398 |
+
|
| 399 |
+
def depth2pointcloud(z, pix_T_cam):
|
| 400 |
+
H = z.shape[0]
|
| 401 |
+
W = z.shape[1]
|
| 402 |
+
y, x = meshgrid2d(H, W)
|
| 403 |
+
z = np.reshape(z, [H, W])
|
| 404 |
+
|
| 405 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
| 406 |
+
xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
|
| 407 |
+
return xyz
|
| 408 |
+
|
| 409 |
+
def meshgrid2d(Y, X):
|
| 410 |
+
grid_y = np.linspace(0.0, Y-1, Y)
|
| 411 |
+
grid_y = np.reshape(grid_y, [Y, 1])
|
| 412 |
+
grid_y = np.tile(grid_y, [1, X])
|
| 413 |
+
|
| 414 |
+
grid_x = np.linspace(0.0, X-1, X)
|
| 415 |
+
grid_x = np.reshape(grid_x, [1, X])
|
| 416 |
+
grid_x = np.tile(grid_x, [Y, 1])
|
| 417 |
+
|
| 418 |
+
# outputs are Y x X
|
| 419 |
+
return grid_y, grid_x
|
| 420 |
+
|
| 421 |
+
def gridcloud3d(Y, X, Z):
|
| 422 |
+
x_ = np.linspace(0, X-1, X)
|
| 423 |
+
y_ = np.linspace(0, Y-1, Y)
|
| 424 |
+
z_ = np.linspace(0, Z-1, Z)
|
| 425 |
+
y, x, z = np.meshgrid(y_, x_, z_, indexing='ij')
|
| 426 |
+
x = np.reshape(x, [-1])
|
| 427 |
+
y = np.reshape(y, [-1])
|
| 428 |
+
z = np.reshape(z, [-1])
|
| 429 |
+
xyz = np.stack([x,y,z], axis=1).astype(np.float32)
|
| 430 |
+
return xyz
|
| 431 |
+
|
| 432 |
+
def gridcloud2d(Y, X):
|
| 433 |
+
x_ = np.linspace(0, X-1, X)
|
| 434 |
+
y_ = np.linspace(0, Y-1, Y)
|
| 435 |
+
y, x = np.meshgrid(y_, x_, indexing='ij')
|
| 436 |
+
x = np.reshape(x, [-1])
|
| 437 |
+
y = np.reshape(y, [-1])
|
| 438 |
+
xy = np.stack([x,y], axis=1).astype(np.float32)
|
| 439 |
+
return xy
|
| 440 |
+
|
| 441 |
+
def normalize(im):
|
| 442 |
+
im = im - np.min(im)
|
| 443 |
+
im = im / np.max(im)
|
| 444 |
+
return im
|
| 445 |
+
|
| 446 |
+
def wrap2pi(rad_angle):
|
| 447 |
+
# rad_angle can be any shape
|
| 448 |
+
# puts the angle into the range [-pi, pi]
|
| 449 |
+
return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
|
| 450 |
+
|
| 451 |
+
def convert_occ_to_height(occ):
|
| 452 |
+
Z, Y, X, C = occ.shape
|
| 453 |
+
assert(C==1)
|
| 454 |
+
|
| 455 |
+
height = np.linspace(float(Y), 1.0, Y)
|
| 456 |
+
height = np.reshape(height, [1, Y, 1, 1])
|
| 457 |
+
height = np.max(occ*height, axis=1)/float(Y)
|
| 458 |
+
height = np.reshape(height, [Z, X, C])
|
| 459 |
+
return height
|
| 460 |
+
|
| 461 |
+
def create_depth_image(xy, Z, H, W):
|
| 462 |
+
|
| 463 |
+
# turn the xy coordinates into image inds
|
| 464 |
+
xy = np.round(xy)
|
| 465 |
+
|
| 466 |
+
# lidar reports a sphere of measurements
|
| 467 |
+
# only use the inds that are within the image bounds
|
| 468 |
+
# also, only use forward-pointing depths (Z > 0)
|
| 469 |
+
valid = (xy[:,0] < W-1) & (xy[:,1] < H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (Z[:] > 0)
|
| 470 |
+
|
| 471 |
+
# gather these up
|
| 472 |
+
xy = xy[valid]
|
| 473 |
+
Z = Z[valid]
|
| 474 |
+
|
| 475 |
+
inds = sub2ind(H,W,xy[:,1],xy[:,0])
|
| 476 |
+
depth = np.zeros((H*W), np.float32)
|
| 477 |
+
|
| 478 |
+
for (index, replacement) in zip(inds, Z):
|
| 479 |
+
depth[index] = replacement
|
| 480 |
+
depth[np.where(depth == 0.0)] = 70.0
|
| 481 |
+
depth = np.reshape(depth, [H, W])
|
| 482 |
+
|
| 483 |
+
return depth
|
| 484 |
+
|
| 485 |
+
def vis_depth(depth, maxdepth=80.0, log_vis=True):
|
| 486 |
+
depth[depth<=0.0] = maxdepth
|
| 487 |
+
if log_vis:
|
| 488 |
+
depth = np.log(depth)
|
| 489 |
+
depth = np.clip(depth, 0, np.log(maxdepth))
|
| 490 |
+
else:
|
| 491 |
+
depth = np.clip(depth, 0, maxdepth)
|
| 492 |
+
depth = (depth*255.0).astype(np.uint8)
|
| 493 |
+
return depth
|
| 494 |
+
|
| 495 |
+
def preprocess_color(x):
|
| 496 |
+
return x.astype(np.float32) * 1./255 - 0.5
|
| 497 |
+
|
| 498 |
+
def convert_box_to_ref_T_obj(boxes):
|
| 499 |
+
shape = boxes.shape
|
| 500 |
+
boxes = boxes.reshape(-1,9)
|
| 501 |
+
rots = [eul2rotm(rx,ry,rz)
|
| 502 |
+
for rx,ry,rz in boxes[:,6:]]
|
| 503 |
+
rots = np.stack(rots,axis=0)
|
| 504 |
+
trans = boxes[:,:3]
|
| 505 |
+
ref_T_objs = [merge_rt(rot,tran)
|
| 506 |
+
for rot,tran in zip(rots,trans)]
|
| 507 |
+
ref_T_objs = np.stack(ref_T_objs,axis=0)
|
| 508 |
+
ref_T_objs = ref_T_objs.reshape(shape[:-1]+(4,4))
|
| 509 |
+
ref_T_objs = ref_T_objs.astype(np.float32)
|
| 510 |
+
return ref_T_objs
|
| 511 |
+
|
| 512 |
+
def get_rot_from_delta(delta, yaw_only=False):
|
| 513 |
+
dx = delta[:,0]
|
| 514 |
+
dy = delta[:,1]
|
| 515 |
+
dz = delta[:,2]
|
| 516 |
+
|
| 517 |
+
bot_hyp = np.sqrt(dz**2 + dx**2)
|
| 518 |
+
# top_hyp = np.sqrt(bot_hyp**2 + dy**2)
|
| 519 |
+
|
| 520 |
+
pitch = -np.arctan2(dy, bot_hyp)
|
| 521 |
+
yaw = np.arctan2(dz, dx)
|
| 522 |
+
|
| 523 |
+
if yaw_only:
|
| 524 |
+
rot = [eul2rotm(0,y,0) for y in yaw]
|
| 525 |
+
else:
|
| 526 |
+
rot = [eul2rotm(0,y,p) for (p,y) in zip(pitch,yaw)]
|
| 527 |
+
|
| 528 |
+
rot = np.stack(rot)
|
| 529 |
+
# rot is B x 3 x 3
|
| 530 |
+
return rot
|
| 531 |
+
|
| 532 |
+
def im2col(im, psize):
|
| 533 |
+
n_channels = 1 if len(im.shape) == 2 else im.shape[0]
|
| 534 |
+
(n_channels, rows, cols) = (1,) * (3 - len(im.shape)) + im.shape
|
| 535 |
+
|
| 536 |
+
im_pad = np.zeros((n_channels,
|
| 537 |
+
int(math.ceil(1.0 * rows / psize) * psize),
|
| 538 |
+
int(math.ceil(1.0 * cols / psize) * psize)))
|
| 539 |
+
im_pad[:, 0:rows, 0:cols] = im
|
| 540 |
+
|
| 541 |
+
final = np.zeros((im_pad.shape[1], im_pad.shape[2], n_channels,
|
| 542 |
+
psize, psize))
|
| 543 |
+
for c in np.arange(n_channels):
|
| 544 |
+
for x in np.arange(psize):
|
| 545 |
+
for y in np.arange(psize):
|
| 546 |
+
im_shift = np.vstack(
|
| 547 |
+
(im_pad[c, x:], im_pad[c, :x]))
|
| 548 |
+
im_shift = np.column_stack(
|
| 549 |
+
(im_shift[:, y:], im_shift[:, :y]))
|
| 550 |
+
final[x::psize, y::psize, c] = np.swapaxes(
|
| 551 |
+
im_shift.reshape(int(im_pad.shape[1] / psize), psize,
|
| 552 |
+
int(im_pad.shape[2] / psize), psize), 1, 2)
|
| 553 |
+
|
| 554 |
+
return np.squeeze(final[0:rows - psize + 1, 0:cols - psize + 1])
|
| 555 |
+
|
| 556 |
+
def filter_discontinuities(depth, filter_size=9, thresh=10):
|
| 557 |
+
H, W = list(depth.shape)
|
| 558 |
+
|
| 559 |
+
# Ensure that filter sizes are okay
|
| 560 |
+
assert filter_size % 2 == 1, "Can only use odd filter sizes."
|
| 561 |
+
|
| 562 |
+
# Compute discontinuities
|
| 563 |
+
offset = int((filter_size - 1) / 2)
|
| 564 |
+
patches = 1.0 * im2col(depth, filter_size)
|
| 565 |
+
mids = patches[:, :, offset, offset]
|
| 566 |
+
mins = np.min(patches, axis=(2, 3))
|
| 567 |
+
maxes = np.max(patches, axis=(2, 3))
|
| 568 |
+
|
| 569 |
+
discont = np.maximum(np.abs(mins - mids),
|
| 570 |
+
np.abs(maxes - mids))
|
| 571 |
+
mark = discont > thresh
|
| 572 |
+
|
| 573 |
+
# Account for offsets
|
| 574 |
+
final_mark = np.zeros((H, W), dtype=np.uint16)
|
| 575 |
+
final_mark[offset:offset + mark.shape[0],
|
| 576 |
+
offset:offset + mark.shape[1]] = mark
|
| 577 |
+
|
| 578 |
+
return depth * (1 - final_mark)
|
| 579 |
+
|
| 580 |
+
def argmax2d(tensor):
|
| 581 |
+
Y, X = list(tensor.shape)
|
| 582 |
+
# flatten the Tensor along the height and width axes
|
| 583 |
+
flat_tensor = tensor.reshape(-1)
|
| 584 |
+
# argmax of the flat tensor
|
| 585 |
+
argmax = np.argmax(flat_tensor)
|
| 586 |
+
|
| 587 |
+
# convert the indices into 2d coordinates
|
| 588 |
+
argmax_y = argmax // X # row
|
| 589 |
+
argmax_x = argmax % X # col
|
| 590 |
+
|
| 591 |
+
return argmax_y, argmax_x
|
| 592 |
+
|
| 593 |
+
def plot_traj_3d(traj):
|
| 594 |
+
# traj is S x 3
|
| 595 |
+
|
| 596 |
+
# print('traj', traj.shape)
|
| 597 |
+
S, C = list(traj.shape)
|
| 598 |
+
assert(C==3)
|
| 599 |
+
|
| 600 |
+
fig = plt.figure()
|
| 601 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 602 |
+
|
| 603 |
+
colors = [plt.cm.RdYlBu(i) for i in np.linspace(0,1,S)]
|
| 604 |
+
# print('colors', colors)
|
| 605 |
+
|
| 606 |
+
xs = traj[:,0]
|
| 607 |
+
ys = -traj[:,1]
|
| 608 |
+
zs = traj[:,2]
|
| 609 |
+
|
| 610 |
+
ax.scatter(xs, zs, ys, s=30, c=colors, marker='o', alpha=1.0, edgecolors=(0,0,0))#, color=color_map[n])
|
| 611 |
+
|
| 612 |
+
ax.set_xlabel('X')
|
| 613 |
+
ax.set_ylabel('Z')
|
| 614 |
+
ax.set_zlabel('Y')
|
| 615 |
+
|
| 616 |
+
ax.set_xlim(0,1)
|
| 617 |
+
ax.set_ylim(0,1) # this is really Z
|
| 618 |
+
ax.set_zlim(-1,0) # this is really Y
|
| 619 |
+
|
| 620 |
+
buf = io.BytesIO()
|
| 621 |
+
plt.savefig(buf, format='png')
|
| 622 |
+
buf.seek(0)
|
| 623 |
+
image = np.array(Image.open(buf)) # H x W x 4
|
| 624 |
+
image = image[:,:,:3]
|
| 625 |
+
|
| 626 |
+
plt.close()
|
| 627 |
+
return image
|
| 628 |
+
|
| 629 |
+
def camera2pixels(xyz, pix_T_cam):
|
| 630 |
+
# xyz is shaped N x 3
|
| 631 |
+
# returns xy, shaped N x 2
|
| 632 |
+
|
| 633 |
+
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
|
| 634 |
+
x, y, z = xyz[:,0], xyz[:,1], xyz[:,2]
|
| 635 |
+
|
| 636 |
+
EPS = 1e-4
|
| 637 |
+
z = np.clip(z, EPS, None)
|
| 638 |
+
x = (x*fx)/z + x0
|
| 639 |
+
y = (y*fy)/z + y0
|
| 640 |
+
xy = np.stack([x, y], axis=-1)
|
| 641 |
+
return xy
|
| 642 |
+
|
| 643 |
+
def make_colorwheel():
|
| 644 |
+
"""
|
| 645 |
+
Generates a color wheel for optical flow visualization as presented in:
|
| 646 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
| 647 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
| 648 |
+
|
| 649 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
| 650 |
+
Code follows the the Matlab source code of Deqing Sun.
|
| 651 |
+
|
| 652 |
+
Returns:
|
| 653 |
+
np.ndarray: Color wheel
|
| 654 |
+
"""
|
| 655 |
+
|
| 656 |
+
RY = 15
|
| 657 |
+
YG = 6
|
| 658 |
+
GC = 4
|
| 659 |
+
CB = 11
|
| 660 |
+
BM = 13
|
| 661 |
+
MR = 6
|
| 662 |
+
|
| 663 |
+
ncols = RY + YG + GC + CB + BM + MR
|
| 664 |
+
colorwheel = np.zeros((ncols, 3))
|
| 665 |
+
col = 0
|
| 666 |
+
|
| 667 |
+
# RY
|
| 668 |
+
colorwheel[0:RY, 0] = 255
|
| 669 |
+
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
|
| 670 |
+
col = col+RY
|
| 671 |
+
# YG
|
| 672 |
+
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
|
| 673 |
+
colorwheel[col:col+YG, 1] = 255
|
| 674 |
+
col = col+YG
|
| 675 |
+
# GC
|
| 676 |
+
colorwheel[col:col+GC, 1] = 255
|
| 677 |
+
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
|
| 678 |
+
col = col+GC
|
| 679 |
+
# CB
|
| 680 |
+
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
|
| 681 |
+
colorwheel[col:col+CB, 2] = 255
|
| 682 |
+
col = col+CB
|
| 683 |
+
# BM
|
| 684 |
+
colorwheel[col:col+BM, 2] = 255
|
| 685 |
+
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
|
| 686 |
+
col = col+BM
|
| 687 |
+
# MR
|
| 688 |
+
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
|
| 689 |
+
colorwheel[col:col+MR, 0] = 255
|
| 690 |
+
return colorwheel
|
| 691 |
+
|
| 692 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
| 693 |
+
"""
|
| 694 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
| 695 |
+
|
| 696 |
+
According to the C++ source code of Daniel Scharstein
|
| 697 |
+
According to the Matlab source code of Deqing Sun
|
| 698 |
+
|
| 699 |
+
Args:
|
| 700 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
| 701 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
| 702 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
| 703 |
+
|
| 704 |
+
Returns:
|
| 705 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
| 706 |
+
"""
|
| 707 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
| 708 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
| 709 |
+
ncols = colorwheel.shape[0]
|
| 710 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
| 711 |
+
a = np.arctan2(-v, -u)/np.pi
|
| 712 |
+
fk = (a+1) / 2*(ncols-1)
|
| 713 |
+
k0 = np.floor(fk).astype(np.int32)
|
| 714 |
+
k1 = k0 + 1
|
| 715 |
+
k1[k1 == ncols] = 0
|
| 716 |
+
f = fk - k0
|
| 717 |
+
for i in range(colorwheel.shape[1]):
|
| 718 |
+
tmp = colorwheel[:,i]
|
| 719 |
+
col0 = tmp[k0] / 255.0
|
| 720 |
+
col1 = tmp[k1] / 255.0
|
| 721 |
+
col = (1-f)*col0 + f*col1
|
| 722 |
+
idx = (rad <= 1)
|
| 723 |
+
col[idx] = 1 - rad[idx] * (1-col[idx])
|
| 724 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
| 725 |
+
# Note the 2-i => BGR instead of RGB
|
| 726 |
+
ch_idx = 2-i if convert_to_bgr else i
|
| 727 |
+
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
| 728 |
+
return flow_image
|
| 729 |
+
|
| 730 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
| 731 |
+
"""
|
| 732 |
+
Expects a two dimensional flow image of shape.
|
| 733 |
+
|
| 734 |
+
Args:
|
| 735 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
| 736 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
| 737 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
| 738 |
+
|
| 739 |
+
Returns:
|
| 740 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
| 741 |
+
"""
|
| 742 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
| 743 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
| 744 |
+
if clip_flow is not None:
|
| 745 |
+
flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) / clip_flow
|
| 746 |
+
# flow_uv = np.clamp(flow, -clip, clip)/clip
|
| 747 |
+
|
| 748 |
+
u = flow_uv[:,:,0]
|
| 749 |
+
v = flow_uv[:,:,1]
|
| 750 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
| 751 |
+
rad_max = np.max(rad)
|
| 752 |
+
epsilon = 1e-5
|
| 753 |
+
u = u / (rad_max + epsilon)
|
| 754 |
+
v = v / (rad_max + epsilon)
|
| 755 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
utils/samp.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import utils.basic
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
| 6 |
+
r"""Sample a tensor using bilinear interpolation
|
| 7 |
+
|
| 8 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
| 9 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
| 10 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
| 11 |
+
convention.
|
| 12 |
+
|
| 13 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
| 14 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
| 15 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
| 16 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
| 17 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
| 18 |
+
|
| 19 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
| 20 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
| 21 |
+
that in this case the order of the components is slightly different
|
| 22 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
| 23 |
+
|
| 24 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
| 25 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
| 26 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
| 27 |
+
pixel.
|
| 28 |
+
|
| 29 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
| 30 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
| 31 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
| 32 |
+
pixel.
|
| 33 |
+
|
| 34 |
+
Similar conventions apply to the :math:`y` for the range
|
| 35 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
| 36 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
input (Tensor): batch of input images.
|
| 40 |
+
coords (Tensor): batch of coordinates.
|
| 41 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
| 42 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Tensor: sampled points.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
sizes = input.shape[2:]
|
| 49 |
+
|
| 50 |
+
assert len(sizes) in [2, 3]
|
| 51 |
+
|
| 52 |
+
if len(sizes) == 3:
|
| 53 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
| 54 |
+
coords = coords[..., [1, 2, 0]]
|
| 55 |
+
|
| 56 |
+
if align_corners:
|
| 57 |
+
coords = coords * torch.tensor(
|
| 58 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
| 59 |
+
)
|
| 60 |
+
else:
|
| 61 |
+
coords = coords * torch.tensor(
|
| 62 |
+
[2 / size for size in reversed(sizes)], device=coords.device
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
coords -= 1
|
| 66 |
+
|
| 67 |
+
return F.grid_sample(
|
| 68 |
+
input, coords, align_corners=align_corners, padding_mode=padding_mode
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def sample_features4d(input, coords):
|
| 73 |
+
r"""Sample spatial features
|
| 74 |
+
|
| 75 |
+
`sample_features4d(input, coords)` samples the spatial features
|
| 76 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
| 77 |
+
|
| 78 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
| 79 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
| 80 |
+
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
| 81 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
| 82 |
+
|
| 83 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
| 84 |
+
R, C)`.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
input (Tensor): spatial features.
|
| 88 |
+
coords (Tensor): points.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Tensor: sampled features.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
B, _, _, _ = input.shape
|
| 95 |
+
|
| 96 |
+
# B R 2 -> B R 1 2
|
| 97 |
+
coords = coords.unsqueeze(2)
|
| 98 |
+
|
| 99 |
+
# B C R 1
|
| 100 |
+
feats = bilinear_sampler(input, coords)
|
| 101 |
+
|
| 102 |
+
return feats.permute(0, 2, 1, 3).view(
|
| 103 |
+
B, -1, feats.shape[1] * feats.shape[3]
|
| 104 |
+
) # B C R 1 -> B R C
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def sample_features5d(input, coords):
|
| 108 |
+
r"""Sample spatio-temporal features
|
| 109 |
+
|
| 110 |
+
`sample_features5d(input, coords)` works in the same way as
|
| 111 |
+
:func:`sample_features4d` but for spatio-temporal features and points:
|
| 112 |
+
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
| 113 |
+
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
| 114 |
+
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
input (Tensor): spatio-temporal features.
|
| 118 |
+
coords (Tensor): spatio-temporal points.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Tensor: sampled features.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
B, T, _, _, _ = input.shape
|
| 125 |
+
|
| 126 |
+
# B T C H W -> B C T H W
|
| 127 |
+
input = input.permute(0, 2, 1, 3, 4)
|
| 128 |
+
|
| 129 |
+
# B R1 R2 3 -> B R1 R2 1 3
|
| 130 |
+
coords = coords.unsqueeze(3)
|
| 131 |
+
|
| 132 |
+
# B C R1 R2 1
|
| 133 |
+
feats = bilinear_sampler(input, coords)
|
| 134 |
+
|
| 135 |
+
return feats.permute(0, 2, 3, 1, 4).view(
|
| 136 |
+
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
| 137 |
+
) # B C R1 R2 1 -> B R1 R2 C
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def bilinear_sample2d(im, x, y, return_inbounds=False):
|
| 141 |
+
# x and y are each B, N
|
| 142 |
+
# output is B, C, N
|
| 143 |
+
B, C, H, W = list(im.shape)
|
| 144 |
+
N = list(x.shape)[1]
|
| 145 |
+
|
| 146 |
+
x = x.float()
|
| 147 |
+
y = y.float()
|
| 148 |
+
H_f = torch.tensor(H, dtype=torch.float32)
|
| 149 |
+
W_f = torch.tensor(W, dtype=torch.float32)
|
| 150 |
+
|
| 151 |
+
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
| 152 |
+
|
| 153 |
+
max_y = (H_f - 1).int()
|
| 154 |
+
max_x = (W_f - 1).int()
|
| 155 |
+
|
| 156 |
+
x0 = torch.floor(x).int()
|
| 157 |
+
x1 = x0 + 1
|
| 158 |
+
y0 = torch.floor(y).int()
|
| 159 |
+
y1 = y0 + 1
|
| 160 |
+
|
| 161 |
+
x0_clip = torch.clamp(x0, 0, max_x)
|
| 162 |
+
x1_clip = torch.clamp(x1, 0, max_x)
|
| 163 |
+
y0_clip = torch.clamp(y0, 0, max_y)
|
| 164 |
+
y1_clip = torch.clamp(y1, 0, max_y)
|
| 165 |
+
dim2 = W
|
| 166 |
+
dim1 = W * H
|
| 167 |
+
|
| 168 |
+
base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
|
| 169 |
+
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
| 170 |
+
|
| 171 |
+
base_y0 = base + y0_clip * dim2
|
| 172 |
+
base_y1 = base + y1_clip * dim2
|
| 173 |
+
|
| 174 |
+
idx_y0_x0 = base_y0 + x0_clip
|
| 175 |
+
idx_y0_x1 = base_y0 + x1_clip
|
| 176 |
+
idx_y1_x0 = base_y1 + x0_clip
|
| 177 |
+
idx_y1_x1 = base_y1 + x1_clip
|
| 178 |
+
|
| 179 |
+
# use the indices to lookup pixels in the flat image
|
| 180 |
+
# im is B x C x H x W
|
| 181 |
+
# move C out to last dim
|
| 182 |
+
im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
|
| 183 |
+
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
| 184 |
+
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
| 185 |
+
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
| 186 |
+
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
| 187 |
+
|
| 188 |
+
# Finally calculate interpolated values.
|
| 189 |
+
x0_f = x0.float()
|
| 190 |
+
x1_f = x1.float()
|
| 191 |
+
y0_f = y0.float()
|
| 192 |
+
y1_f = y1.float()
|
| 193 |
+
|
| 194 |
+
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
| 195 |
+
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
| 196 |
+
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
| 197 |
+
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
| 198 |
+
|
| 199 |
+
output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
|
| 200 |
+
w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
|
| 201 |
+
# output is B*N x C
|
| 202 |
+
output = output.view(B, -1, C)
|
| 203 |
+
output = output.permute(0, 2, 1)
|
| 204 |
+
# output is B x C x N
|
| 205 |
+
|
| 206 |
+
if return_inbounds:
|
| 207 |
+
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
| 208 |
+
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
| 209 |
+
inbounds = (x_valid & y_valid).float()
|
| 210 |
+
inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
| 211 |
+
return output, inbounds
|
| 212 |
+
|
| 213 |
+
return output # B, C, N
|
utils/saveload.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
|
| 6 |
+
pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
|
| 7 |
+
prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
|
| 8 |
+
prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
|
| 9 |
+
if len(prev_ckpts) > keep_latest-1:
|
| 10 |
+
for f in prev_ckpts[keep_latest-1:]:
|
| 11 |
+
f.unlink()
|
| 12 |
+
save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
|
| 13 |
+
save_dict = {
|
| 14 |
+
"model": module.state_dict(),
|
| 15 |
+
"optimizer": optimizer.state_dict(),
|
| 16 |
+
"global_step": global_step,
|
| 17 |
+
}
|
| 18 |
+
if scheduler is not None:
|
| 19 |
+
save_dict['scheduler'] = scheduler.state_dict()
|
| 20 |
+
print(f"saving {save_path}")
|
| 21 |
+
torch.save(save_dict, save_path)
|
| 22 |
+
return False
|
| 23 |
+
|
| 24 |
+
def load(fabric, ckpt_path, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None, strict=True, verbose=True, weights_only=False):
|
| 25 |
+
if verbose:
|
| 26 |
+
print('reading ckpt from %s' % ckpt_path)
|
| 27 |
+
if not os.path.exists(ckpt_path):
|
| 28 |
+
print('...there is no full checkpoint in %s' % ckpt_path)
|
| 29 |
+
print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_path --')
|
| 30 |
+
assert(False)
|
| 31 |
+
else:
|
| 32 |
+
if os.path.isfile(ckpt_path):
|
| 33 |
+
path = ckpt_path
|
| 34 |
+
print('...found checkpoint %s' % (path))
|
| 35 |
+
else:
|
| 36 |
+
prev_ckpts = list(pathlib.Path(ckpt_path).glob('%s-*pth' % model_name))
|
| 37 |
+
prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
|
| 38 |
+
if len(prev_ckpts):
|
| 39 |
+
path = prev_ckpts[0]
|
| 40 |
+
# e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
|
| 41 |
+
# OR ./whatever.pth
|
| 42 |
+
step = int(str(path).split('-')[-1].split('.')[0])
|
| 43 |
+
if verbose:
|
| 44 |
+
print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
|
| 45 |
+
else:
|
| 46 |
+
print('...there is no full checkpoint here!')
|
| 47 |
+
return 0
|
| 48 |
+
if fabric is not None:
|
| 49 |
+
checkpoint = fabric.load(path)
|
| 50 |
+
else:
|
| 51 |
+
checkpoint = torch.load(path, weights_only=weights_only)
|
| 52 |
+
if optimizer is not None:
|
| 53 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 54 |
+
if scheduler is not None:
|
| 55 |
+
scheduler.load_state_dict(checkpoint['scheduler'])
|
| 56 |
+
assert ignore_load is None # not ready yet
|
| 57 |
+
if 'model' in checkpoint:
|
| 58 |
+
state_dict = checkpoint['model']
|
| 59 |
+
else:
|
| 60 |
+
state_dict = checkpoint
|
| 61 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 62 |
+
return step
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|