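"""Ray-query model: predict pointmaps for a query camera from context pointmaps.

RayEncoder embeds a 2-channel ray map for the query view, PointmapEncoder
embeds 3-channel context pointmaps (substituting a learned token at masked,
off-object pixels), and RayQuery cross-attends the ray tokens to the pointmap
tokens (optionally concatenated with projected DINO features) through a
decoder, then regresses 3D points, confidences, and a classifier map via the
heads built by models.heads.head_factory.
"""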
import torch
import torch.nn as nn
from models.blocks import DecoderBlock, Block, PatchEmbed, PositionGetter
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
from models.losses import *
from utils.geometry import center_pointmaps, compute_rays
from models.heads import head_factory


def init_weights(m):
    if isinstance(m, nn.Linear):
        # we use xavier_uniform following the official JAX ViT:
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Parameter):
        # note: nn.Module.apply() only visits Modules, never bare Parameters,
        # so this branch is unreachable; masked_token below keeps its randn init
        nn.init.normal_(m, std=0.02)


class RayEncoder(nn.Module):
    def __init__(self,
                 dim=256,
                 patch_size=8,
                 img_size=(128, 128),
                 depth=3,
                 num_heads=4,
                 pos_embed='RoPE100',
                 ):
        super().__init__()
        self.img_size = img_size
        self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=2, embed_dim=dim)
        self.dim = dim
        if pos_embed.startswith('RoPE'):
            freq = float(pos_embed[len('RoPE'):])
            self.rope = RoPE2D(freq=freq)
        else:
            self.rope = None
        self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads, rope=self.rope) for _ in range(depth)])
        self.initialize_weights()

    def initialize_weights(self):
        # patch embed
        self.patch_embed._init_weights()
        # linears and layer norms
        self.apply(init_weights)

    def forward(self, rays):
        rays = rays.permute(0, 3, 1, 2)
        rays, pos = self.patch_embed(rays)
        for blk in self.blocks:
            rays = blk(rays, pos)
        return rays, pos
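

# Usage sketch for RayEncoder (shapes assume the defaults above: img_size=(128, 128),
# patch_size=8, dim=256, i.e. a 16x16 grid of 256 tokens):
#
#   enc = RayEncoder()
#   rays = torch.randn(2, 128, 128, 2)   # (B, H, W, 2) ray map
#   tokens, pos = enc(rays)              # tokens: (2, 256, 256); pos: per-token 2D grid positions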


class PointmapEncoder(nn.Module):
    def __init__(self,
                 dim=256,
                 patch_size=8,
                 img_size=(128, 128),
                 depth=3,
                 num_heads=4,
                 pos_embed='RoPE100',
                 ):
        super().__init__()
        self.img_size = img_size
        self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=3, embed_dim=dim)
        self.dim = dim
        self.patch_size = patch_size
        if pos_embed.startswith('RoPE'):
            freq = float(pos_embed[len('RoPE'):])
            self.rope = RoPE2D(freq=freq)
        else:
            self.rope = None
        self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads, rope=self.rope) for _ in range(depth)])
        # learned 3-vector that stands in for masked (off-object) points
        self.masked_token = nn.Parameter(torch.randn(1, 1, 3))
        self.initialize_weights()

    def initialize_weights(self):
        # patch embed
        self.patch_embed._init_weights()
        # linears and layer norms
        self.apply(init_weights)

    def forward(self, pointmaps, masks=None):
        if masks is not None:
            # replace masked points (not on the object) with the learned token;
            # clone so the caller's tensor is not mutated in place, and flatten
            # the (1, 1, 3) token so it broadcasts over the (N, 3) selection
            pointmaps = pointmaps.clone()
            pointmaps[~masks] = self.masked_token.to(pointmaps).reshape(-1)
        pointmaps = pointmaps.permute(0, 3, 1, 2)
        pointmaps, pos = self.patch_embed(pointmaps)
        for blk in self.blocks:
            pointmaps = blk(pointmaps, pos)
        return pointmaps, pos


class RayQuery(nn.Module):
    def __init__(self,
                 ray_enc=None,
                 pointmap_enc=None,
                 dec_pos_embed='RoPE100',
                 decoder_dim=256,
                 decoder_depth=3,
                 decoder_num_heads=4,
                 imshape=(128, 128),
                 pts_head_type='dpt',
                 classifier_head_type='dpt_mask',
                 criterion=None,
                 return_all_blocks=True,
                 depth_mode=('exp', -float('inf'), float('inf')),
                 conf_mode=('exp', 1, float('inf')),
                 classifier_mode=('raw', 0, 1),
                 dino_layers=(23,),
                 ):
        super().__init__()
        # module instances as default arguments would be built once at class
        # definition and shared across every RayQuery; construct fresh ones here
        self.ray_enc = ray_enc if ray_enc is not None else RayEncoder()
        self.pointmap_enc = pointmap_enc if pointmap_enc is not None else PointmapEncoder()
        self.criterion = criterion if criterion is not None else ConfLoss(L21)
        self.dec_depth = decoder_depth
        self.dec_embed_dim = decoder_dim
        self.enc_embed_dim = self.ray_enc.dim
        self.patch_size = self.pointmap_enc.patch_size
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.classifier_mode = classifier_mode
        self.skip_dino = len(dino_layers) == 0
        self.pts_head_type = pts_head_type
        self.classifier_head_type = classifier_head_type
        if dec_pos_embed.startswith('RoPE'):
            # parse the RoPE frequency from the suffix, matching the encoders
            freq = float(dec_pos_embed[len('RoPE'):])
            self.dec_pos_embed = RoPE2D(freq=freq)
        else:
            raise NotImplementedError(f'{dec_pos_embed} not implemented')
        self.decoder_blocks = nn.ModuleList([DecoderBlock(dim=decoder_dim, num_heads=decoder_num_heads,
                                                          rope=self.dec_pos_embed) for _ in range(decoder_depth)])
        self.pts_head = head_factory(pts_head_type, 'pts3d', self, has_conf=True)
        self.classifier_head = head_factory(classifier_head_type, 'pts3d', self, has_conf=True)
        self.imshape = imshape
        self.return_all_blocks = return_all_blocks
        # DINO projection
        self.dino_layers = dino_layers
        self.dino_proj = nn.Linear(1024 * len(dino_layers), decoder_dim)
        self.dino_pos_getter = PositionGetter()
        self.initialize_weights()
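
    # Construction sketch (a hedged example; head types are resolved by
    # models.heads.head_factory):
    #
    #   model = RayQuery(dino_layers=(23,))  # project DINO layer-23 features into the decoder context
    #   model = RayQuery(dino_layers=())     # empty dino_layers skips the DINO branch entirely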

    def initialize_weights(self):
        self.apply(init_weights)

    def forward_encoders(self, rays, pointmaps, masks=None):
        # encode rays
        rays, rays_pos = self.ray_enc(rays)
        # encode pointmaps; with a single context view the batch is already
        # (B, H, W, C) -- a multi-view batch would be flattened to (B*V, H, W, C)
        # first so that each pointmap is encoded separately
        pointmaps, pointmaps_pos = self.pointmap_enc(pointmaps, masks=masks)
        return rays, rays_pos, pointmaps, pointmaps_pos

    def forward_decoder(self, rays, rays_pos, pointmaps, pointmaps_pos):
        if self.return_all_blocks:
            all_blocks = []
            for blk in self.decoder_blocks:
                rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
                all_blocks.append(rays)
            return all_blocks
        else:
            for blk in self.decoder_blocks:
                rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
            return rays

    def get_dino_pos(self, dino_features):
        # DINO runs on 14x14 patches; this assumes the image was cropped or
        # resized so that imshape is divisible by 14
        dino_H = self.imshape[0] // 14
        dino_W = self.imshape[1] // 14
        dino_pos = self.dino_pos_getter(dino_features.shape[0], dino_H, dino_W, dino_features.device)
        return dino_pos
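
    # Grid-size note for get_dino_pos: imshape=(224, 224) gives a 16x16 grid
    # (256 positions), while the default imshape=(128, 128) gives 9x9 since
    # 128 // 14 == 9, silently dropping the remainder.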

    def forward(self, batch, mode='loss'):
        # prep for the encoders
        rays = compute_rays(batch)  # we are querying the first camera
        pointmaps_context = batch['input_cams']['pointmaps']  # the other cameras serve as context
        input_masks = batch['input_cams']['valid_masks']
        # run the encoders
        rays, rays_pos, pointmaps, pointmaps_pos = self.forward_encoders(rays, pointmaps_context, masks=input_masks)
        # append projected DINO features as extra context tokens
        if not self.skip_dino:
            dino_features = batch['input_cams']['dino_features']
            dino_features = self.dino_proj(dino_features)
            if len(dino_features.shape) == 4:
                dino_features = dino_features.squeeze(1)
            dino_pos = self.get_dino_pos(dino_features)
            pointmaps = torch.cat([pointmaps, dino_features], dim=1)
            pointmaps_pos = torch.cat([pointmaps_pos, dino_pos], dim=1)
        else:
            dino_features = None
            dino_pos = None
        # decoder and prediction heads
        rays = self.forward_decoder(rays, rays_pos, pointmaps, pointmaps_pos)
        pts_pred_dict = self.pts_head(rays, self.imshape)
        classifier_pred_dict = self.classifier_head(rays, self.imshape)
        pred_dict = {**pts_pred_dict, **classifier_pred_dict}
        gt_dict = batch['new_cams']
        loss_dict = self.criterion(pred_dict, gt_dict)
        # free intermediates that are no longer needed
        del rays, rays_pos, pointmaps, pointmaps_pos, dino_features, dino_pos, \
            pointmaps_context, input_masks, pts_pred_dict, classifier_pred_dict
        if mode == 'loss':
            del pred_dict, gt_dict
            return loss_dict
        elif mode == 'viz':
            return pred_dict, gt_dict, loss_dict
        else:
            raise ValueError(f"Invalid mode: {mode}")
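

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the training pipeline. It
    # exercises the two encoders with random inputs at the default shapes and
    # assumes the models/ package (PatchEmbed, Block, RoPE2D, ...) is importable.
    # Running RayQuery.forward end-to-end needs a dataset batch with the
    # 'input_cams' / 'new_cams' keys used above, which is out of scope here.
    ray_enc = RayEncoder()
    pm_enc = PointmapEncoder()
    rays = torch.randn(2, 128, 128, 2)        # (B, H, W, 2) query ray map
    pointmaps = torch.randn(2, 128, 128, 3)   # (B, H, W, 3) context pointmaps
    masks = torch.rand(2, 128, 128) > 0.5     # boolean validity masks
    ray_tokens, ray_pos = ray_enc(rays)
    pm_tokens, pm_pos = pm_enc(pointmaps, masks=masks)
    print('ray tokens:', tuple(ray_tokens.shape), '| pointmap tokens:', tuple(pm_tokens.shape))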