# rayst3r / models / rayquery.py
# (hosting-page chrome commented out so the file parses as Python:
#  author "bartduis", commit "init" 70d1188, raw/history/blame links, 8.95 kB)
# Removed debugging leftover (`bb = breakpoint`): it aliased the builtin
# debugger hook into the module namespace and served no runtime purpose.
import torch
import torch.nn as nn
from models.blocks import DecoderBlock, Block, PatchEmbed, PositionGetter
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
from models.losses import *
from utils.geometry import center_pointmaps, compute_rays
from models.heads import head_factory
def init_weights(m):
    """Per-submodule initializer, intended for use with ``nn.Module.apply``.

    Linear layers get Xavier-uniform weights (following the official JAX ViT)
    and zero biases; LayerNorm affine parameters are reset to the identity
    transform (weight=1, bias=0).

    NOTE(review): ``Module.apply`` only visits submodules, never bare
    ``nn.Parameter`` objects, so the final branch looks unreachable in that
    usage — kept in case the function is also called directly; confirm before
    removing.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        for param, value in ((m.bias, 0), (m.weight, 1.0)):
            if param is not None:
                nn.init.constant_(param, value)
    elif isinstance(m, nn.Parameter):
        nn.init.normal_(m, std=0.02)
class RayEncoder(nn.Module):
    """Patchify per-pixel ray maps (2 channels) and refine the resulting
    tokens with a small stack of self-attention blocks.

    ``forward`` returns the token features together with the 2-D token
    positions produced by the patch embedder.
    """
    def __init__(self,
                 dim=256,
                 patch_size=8,
                 img_size=(128,128),
                 depth=3,
                 num_heads=4,
                 pos_embed='RoPE100',
                 ):
        super().__init__()
        self.img_size = img_size
        self.dim = dim
        self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=2, embed_dim=dim)
        # 'RoPE<freq>' enables rotary position embeddings with the given
        # frequency; any other string disables them.
        if pos_embed.startswith('RoPE'):
            self.rope = RoPE2D(freq=float(pos_embed[len('RoPE'):]))
        else:
            self.rope = None
        self.blocks = nn.ModuleList(
            Block(dim=dim, num_heads=num_heads, rope=self.rope) for _ in range(depth)
        )
        self.initialize_weights()

    def initialize_weights(self):
        # the patch embedder carries its own init; everything else goes
        # through the shared module-level helper
        self.patch_embed._init_weights()
        self.apply(init_weights)

    def forward(self, rays):
        # (B, H, W, 2) -> (B, 2, H, W) for the conv-based patch embedder
        tokens, pos = self.patch_embed(rays.permute(0, 3, 1, 2))
        for block in self.blocks:
            tokens = block(tokens, pos)
        return tokens, pos
class PointmapEncoder(nn.Module):
    """Patchify 3-channel pointmaps and refine the tokens with self-attention.

    Pixels whose mask entry is False (not on the object) are replaced by a
    learned "masked" token before patch embedding.
    """
    def __init__(self,
                 dim=256,
                 patch_size=8,
                 img_size=(128,128),
                 depth=3,
                 num_heads=4,
                 pos_embed='RoPE100',
                 ):
        super().__init__()
        self.img_size = img_size
        self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=3, embed_dim=dim)
        self.dim = dim
        self.patch_size = patch_size
        # 'RoPE<freq>' enables rotary position embeddings; anything else disables them
        if pos_embed.startswith('RoPE'):
            freq = float(pos_embed[len('RoPE'):])
            self.rope = RoPE2D(freq=freq)
        else:
            self.rope = None
        self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads, rope=self.rope) for _ in range(depth)])
        # learned 3-vector substituted for masked-out points
        self.masked_token = nn.Parameter(torch.randn(1, 1, 3))
        self.initialize_weights()

    def initialize_weights(self):
        # patch embed
        self.patch_embed._init_weights()
        # linears and layer norms
        self.apply(init_weights)

    def forward(self, pointmaps, masks=None):
        """Encode pointmaps into tokens.

        Args:
            pointmaps: (B, H, W, 3) tensor of 3-D points.
            masks: optional boolean tensor broadcastable to (B, H, W); False
                entries are replaced by the learned masked token. When None,
                all points are kept as-is.

        Returns:
            (tokens, pos) from the patch embedder / attention blocks.
        """
        if masks is not None:
            # Bug fix: `~masks` was previously evaluated unconditionally, so
            # calling with the default masks=None raised a TypeError.
            # NOTE(review): this assignment mutates the caller's `pointmaps`
            # tensor in place — confirm that side effect is intended.
            pointmaps[~masks] = self.masked_token.to(pointmaps.dtype).to(pointmaps.device)
        pointmaps = pointmaps.permute(0, 3, 1, 2)
        pointmaps, pos = self.patch_embed(pointmaps)
        for blk in self.blocks:
            pointmaps = blk(pointmaps, pos)
        return pointmaps, pos
class RayQuery(nn.Module):
    """Predict 3-D pointmaps (plus confidence and a mask classifier) for a set
    of query rays, conditioned on context-camera pointmaps and, optionally,
    projected DINO features.
    """
    def __init__(self,
                 ray_enc=None,
                 pointmap_enc=None,
                 dec_pos_embed='RoPE100',
                 decoder_dim=256,
                 decoder_depth=3,
                 decoder_num_heads=4,
                 imshape=(128,128),
                 pts_head_type='dpt',
                 classifier_head_type='dpt_mask',
                 criterion=None,
                 return_all_blocks=True,
                 depth_mode=('exp',-float('inf'),float('inf')),
                 conf_mode=('exp',1,float('inf')),
                 classifier_mode=('raw',0,1),
                 dino_layers=None,
                 ):
        super().__init__()
        # Bug fix: the previous defaults (ray_enc=RayEncoder(),
        # pointmap_enc=PointmapEncoder(), criterion=ConfLoss(L21),
        # dino_layers=[23]) were evaluated once at class-definition time and
        # therefore shared across every default-constructed RayQuery instance.
        # Build them per instance instead; explicit arguments behave as before.
        self.ray_enc = RayEncoder() if ray_enc is None else ray_enc
        self.pointmap_enc = PointmapEncoder() if pointmap_enc is None else pointmap_enc
        if dino_layers is None:
            dino_layers = [23]
        self.dec_depth = decoder_depth
        self.dec_embed_dim = decoder_dim
        self.enc_embed_dim = self.ray_enc.dim
        self.patch_size = self.pointmap_enc.patch_size
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.classifier_mode = classifier_mode
        self.skip_dino = len(dino_layers) == 0
        self.pts_head_type = pts_head_type
        self.classifier_head_type = classifier_head_type
        if dec_pos_embed.startswith('RoPE'):
            # Bug fix: the frequency suffix was previously ignored and the
            # decoder always used freq=100.0; parse it like the encoders do.
            self.dec_pos_embed = RoPE2D(freq=float(dec_pos_embed[len('RoPE'):]))
        else:
            raise NotImplementedError(f'{dec_pos_embed} not implemented')
        self.decoder_blocks = nn.ModuleList([DecoderBlock(dim=decoder_dim, num_heads=decoder_num_heads,
                                                          rope=self.dec_pos_embed) for _ in range(decoder_depth)])
        self.pts_head = head_factory(pts_head_type, 'pts3d', self, has_conf=True)
        self.classifier_head = head_factory(classifier_head_type, 'pts3d', self, has_conf=True)
        self.imshape = imshape
        self.criterion = ConfLoss(L21) if criterion is None else criterion
        self.return_all_blocks = return_all_blocks
        # dino projection (1024 features per selected DINO layer)
        self.dino_layers = dino_layers
        self.dino_proj = nn.Linear(1024 * len(dino_layers), decoder_dim)
        self.dino_pos_getter = PositionGetter()
        self.initialize_weights()

    def initialize_weights(self):
        self.apply(init_weights)

    def forward_encoders(self, rays, pointmaps, masks=None):
        """Encode query rays and context pointmaps into token sequences.

        Returns (ray_tokens, ray_pos, pointmap_tokens, pointmap_pos).
        """
        # encode rays
        rays, rays_pos = self.ray_enc(rays)
        # encode pointmaps (each pointmap is encoded separately)
        B = pointmaps.shape[0]
        pointmaps, pointmaps_pos = self.pointmap_enc(pointmaps, masks=masks)
        # NOTE(review): identity reshapes from a multi-view variant were
        # removed; this slice is kept — presumably a no-op with a single
        # context view, confirm against callers.
        pointmaps_pos = pointmaps_pos[:B]
        return rays, rays_pos, pointmaps, pointmaps_pos

    def forward_decoder(self, rays, rays_pos, pointmaps, pointmaps_pos):
        """Cross-attend ray tokens to pointmap tokens through the decoder.

        Returns either the list of per-block ray tokens (when
        ``return_all_blocks``) or only the final ray tokens.
        """
        if self.return_all_blocks:
            all_blocks = []
            for blk in self.decoder_blocks:
                rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
                all_blocks.append(rays)
            return all_blocks
        else:
            for blk in self.decoder_blocks:
                rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
            return rays

    def get_dino_pos(self, dino_features):
        """Token positions for the DINO feature grid.

        DINO runs on 14x14 patches; assumes the image was cropped or resized
        so imshape is divisible accordingly.
        """
        dino_H = self.imshape[0] // 14
        dino_W = self.imshape[1] // 14
        dino_pos = self.dino_pos_getter(dino_features.shape[0], dino_H, dino_W, dino_features.device)
        return dino_pos

    def forward(self, batch, mode='loss'):
        """Run the full model on a batch.

        Args:
            batch: dict with 'input_cams' (pointmaps, valid_masks, and
                optionally dino_features) and 'new_cams' ground truth.
            mode: 'loss' returns only the loss dict; 'viz' also returns the
                prediction and ground-truth dicts.

        Raises:
            ValueError: for any other mode.
        """
        # prep for encoders
        rays = compute_rays(batch)  # we are querying the first camera
        pointmaps_context = batch['input_cams']['pointmaps']  # the other cameras are context
        input_masks = batch['input_cams']['valid_masks']
        # run the encoders
        rays, rays_pos, pointmaps, pointmaps_pos = self.forward_encoders(rays, pointmaps_context, masks=input_masks)
        ## adding dino features as extra context tokens
        if not self.skip_dino:
            dino_features = batch['input_cams']['dino_features']
            dino_features = self.dino_proj(dino_features)
            if len(dino_features.shape) == 4:
                dino_features = dino_features.squeeze(1)
            dino_pos = self.get_dino_pos(dino_features)
            pointmaps = torch.cat([pointmaps, dino_features], dim=1)
            pointmaps_pos = torch.cat([pointmaps_pos, dino_pos], dim=1)
        else:
            dino_features = None
            dino_pos = None
        # decoder + prediction heads
        rays = self.forward_decoder(rays, rays_pos, pointmaps, pointmaps_pos)
        pts_pred_dict = self.pts_head(rays, self.imshape)
        classifier_pred_dict = self.classifier_head(rays, self.imshape)
        pred_dict = {**pts_pred_dict, **classifier_pred_dict}
        gt_dict = batch['new_cams']
        loss_dict = self.criterion(pred_dict, gt_dict)
        # free large intermediates eagerly to keep peak memory down
        del rays, rays_pos, pointmaps, pointmaps_pos, dino_features, dino_pos, pointmaps_context, input_masks, pts_pred_dict, classifier_pred_dict
        if mode == 'loss':
            # delete all the variables that are not needed
            del pred_dict, gt_dict
            return loss_dict
        elif mode == 'viz':
            return pred_dict, gt_dict, loss_dict
        else:
            raise ValueError(f"Invalid mode: {mode}")