|
# Short alias for the builtin ``breakpoint``, kept as a debugging convenience
# (call ``bb()`` to drop into the debugger). Not used in this section of the file.
bb = breakpoint
|
import torch |
|
import torch.nn as nn |
|
from models.blocks import DecoderBlock, Block, PatchEmbed, PositionGetter |
|
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D |
|
from models.losses import * |
|
from utils.geometry import center_pointmaps, compute_rays |
|
from models.heads import head_factory |
|
|
|
def init_weights(m):
    """Re-initialise one module (or a bare parameter) in place.

    Designed to be passed to ``nn.Module.apply``:

    * ``nn.Linear``: Xavier-uniform weight; bias (if any) set to 0.
    * ``nn.LayerNorm``: bias (if present) set to 0, weight (if present) set to 1.
    * ``nn.Parameter``: drawn from a normal distribution with std 0.02.

    Anything else is left untouched.

    NOTE: ``nn.Module.apply`` only ever hands modules to this function, so the
    ``nn.Parameter`` branch runs only when a parameter is passed in directly.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return

    if isinstance(m, nn.LayerNorm):
        # The affine tensors can be None (e.g. elementwise_affine=False).
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        return

    if isinstance(m, nn.Parameter):
        nn.init.normal_(m, std=0.02)
|
|
|
class RayEncoder(nn.Module):
    """Patch-embed a 2-channel ray map and refine the tokens with transformer blocks.

    Args:
        dim: token embedding width.
        patch_size: patch side length passed to ``PatchEmbed``.
        img_size: input spatial size passed to ``PatchEmbed``.
        depth: number of ``Block`` layers.
        num_heads: attention heads per block.
        pos_embed: ``'RoPE<freq>'`` (e.g. ``'RoPE100'``) enables 2D RoPE with
            that frequency; any other string sets ``self.rope = None``.
    """

    def __init__(self,
                 dim=256,
                 patch_size=8,
                 img_size=(128,128),
                 depth=3,
                 num_heads=4,
                 pos_embed='RoPE100',
                 ):
        super().__init__()
        self.img_size = img_size
        # Ray maps carry two input channels.
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=2, embed_dim=dim)
        self.dim = dim

        prefix = 'RoPE'
        if pos_embed.startswith(prefix):
            self.rope = RoPE2D(freq=float(pos_embed[len(prefix):]))
        else:
            self.rope = None

        # Every block receives the same RoPE module instance.
        self.blocks = nn.ModuleList(
            Block(dim=dim, num_heads=num_heads, rope=self.rope) for _ in range(depth)
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise the patch embedding, then apply ``init_weights`` to all submodules."""
        self.patch_embed._init_weights()
        self.apply(init_weights)

    def forward(self, rays):
        """Encode ``rays`` into tokens.

        Args:
            rays: channels-last ray map, permuted with ``(0, 3, 1, 2)`` before
                patch embedding — presumably ``(B, H, W, 2)``; TODO confirm.

        Returns:
            ``(tokens, pos)``: output of the last block and the positions
            produced by ``PatchEmbed``.
        """
        tokens, pos = self.patch_embed(rays.permute(0, 3, 1, 2))
        for block in self.blocks:
            tokens = block(tokens, pos)
        return tokens, pos
|
|
|
class PointmapEncoder(nn.Module):
    """Encode a 3-channel pointmap into transformer tokens.

    Pixels flagged invalid by ``masks`` are replaced with the learnable
    ``masked_token`` before patch embedding. The patch tokens are then refined
    by ``depth`` ``Block`` layers.

    Args:
        dim: token embedding width.
        patch_size: patch side length passed to ``PatchEmbed`` (also stored as
            ``self.patch_size``).
        img_size: input spatial size passed to ``PatchEmbed``.
        depth: number of ``Block`` layers.
        num_heads: attention heads per block.
        pos_embed: ``'RoPE<freq>'`` (e.g. ``'RoPE100'``) enables 2D RoPE with
            that frequency; any other string sets ``self.rope = None``.
    """

    def __init__(self,
                 dim=256,
                 patch_size=8,
                 img_size=(128,128),
                 depth=3,
                 num_heads=4,
                 pos_embed='RoPE100',
                 ):
        super().__init__()
        self.img_size = img_size
        self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=3, embed_dim=dim)
        self.dim = dim
        self.patch_size = patch_size

        if pos_embed.startswith('RoPE'):
            freq = float(pos_embed[len('RoPE'):])
            self.rope = RoPE2D(freq=freq)
        else:
            self.rope = None
        # All blocks share the single RoPE module above.
        self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads, rope=self.rope) for _ in range(depth)])

        # Learnable fill value for invalid pixels, one entry per input channel.
        # NOTE(review): `initialize_weights` relies on `nn.Module.apply`, which only
        # visits modules. This parameter therefore keeps its `torch.randn`
        # initialisation; the `nn.Parameter` branch of `init_weights` is never
        # reached for it.
        self.masked_token = nn.Parameter(torch.randn(1,1,3))
        self.initialize_weights()

    def initialize_weights(self):
        """Initialise the patch embedding, then apply ``init_weights`` to all submodules."""
        self.patch_embed._init_weights()
        self.apply(init_weights)

    def forward(self, pointmaps, masks=None):
        """Encode ``pointmaps`` into tokens.

        Args:
            pointmaps: channels-last pointmap, permuted with ``(0, 3, 1, 2)``
                before patch embedding — presumably ``(B, H, W, 3)``; TODO confirm.
            masks: optional validity mask over the leading dimensions of
                ``pointmaps`` (presumably ``(B, H, W)``). Entries that are False
                after conversion to bool are replaced by ``masked_token``.
                ``None`` leaves every pixel unchanged.

        Returns:
            ``(tokens, pos)``: output of the last block and the positions
            produced by ``PatchEmbed``.

        The caller's ``pointmaps`` tensor is never modified.
        """
        if masks is not None:
            keep = masks.bool()
            # Append singleton dims so the mask broadcasts over the channel axis.
            keep = keep.reshape(tuple(keep.shape) + (1,) * (pointmaps.dim() - keep.dim()))
            token = self.masked_token.to(device=pointmaps.device, dtype=pointmaps.dtype)
            # Out-of-place replacement. The former `pointmaps[~masks] = ...`
            # overwrote the caller's tensor (e.g. the batch entry) in place.
            pointmaps = torch.where(keep, pointmaps, token)

        pointmaps, pos = self.patch_embed(pointmaps.permute(0, 3, 1, 2))
        for blk in self.blocks:
            pointmaps = blk(pointmaps, pos)
        return pointmaps, pos
|
|
|
class RayQuery(nn.Module):
    """Decode query-ray tokens against context-pointmap (and optional DINO) tokens.

    Rays and pointmaps are encoded separately. The ray tokens are decoded by
    ``DecoderBlock`` layers that attend to the pointmap tokens. Two heads built
    by ``head_factory`` produce predictions, and ``forward`` evaluates
    ``criterion`` on them.

    Args:
        ray_enc: ray encoder. ``None`` (default) builds a fresh ``RayEncoder()``
            for this instance.
        pointmap_enc: pointmap encoder. ``None`` (default) builds a fresh
            ``PointmapEncoder()`` for this instance.
        dec_pos_embed: ``'RoPE<freq>'`` (bare ``'RoPE'`` means 100). Anything
            else raises ``NotImplementedError``.
        decoder_dim, decoder_depth, decoder_num_heads: decoder configuration.
            The DINO projection outputs ``decoder_dim`` and is concatenated with
            the pointmap tokens, so the pointmap encoder's ``dim`` must match it
            when DINO is enabled.
        imshape: shape passed to the heads; also used to size the DINO token grid.
        pts_head_type, classifier_head_type: head names passed to ``head_factory``.
        criterion: loss called as ``criterion(pred_dict, gt_dict)``. ``None``
            (default) builds ``ConfLoss(L21)``.
        return_all_blocks: if True the decoder returns the list of per-block ray
            tokens, otherwise only the last ones.
        depth_mode, conf_mode, classifier_mode: stored on ``self``; presumably
            read by the heads built by ``head_factory`` — TODO confirm.
        dino_layers: DINO layers whose features are provided. An empty list
            disables DINO conditioning. ``None`` (default) means ``[23]``.
    """

    def __init__(self,
                 ray_enc=None,
                 pointmap_enc=None,
                 dec_pos_embed='RoPE100',
                 decoder_dim=256,
                 decoder_depth=3,
                 decoder_num_heads=4,
                 imshape=(128,128),
                 pts_head_type='dpt',
                 classifier_head_type='dpt_mask',
                 criterion=None,
                 return_all_blocks=True,
                 depth_mode=('exp',-float('inf'),float('inf')),
                 conf_mode=('exp',1,float('inf')),
                 classifier_mode=('raw',0,1),
                 dino_layers=None,
                 ):
        super().__init__()
        # Build default components per instance. Defaults written in the
        # signature were created once at import time and shared, parameters
        # included, by every RayQuery built with defaults.
        if ray_enc is None:
            ray_enc = RayEncoder()
        if pointmap_enc is None:
            pointmap_enc = PointmapEncoder()
        if criterion is None:
            criterion = ConfLoss(L21)
        if dino_layers is None:
            dino_layers = [23]

        self.ray_enc = ray_enc
        self.pointmap_enc = pointmap_enc
        self.dec_depth = decoder_depth
        self.dec_embed_dim = decoder_dim
        self.enc_embed_dim = ray_enc.dim
        self.patch_size = pointmap_enc.patch_size
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.classifier_mode = classifier_mode
        self.skip_dino = len(dino_layers) == 0
        self.pts_head_type = pts_head_type
        self.classifier_head_type = classifier_head_type

        # Honour the frequency in the spec (previously always 100.0), matching
        # how the encoders parse their `pos_embed` strings.
        self.dec_pos_embed = RoPE2D(freq=self._rope_freq(dec_pos_embed))
        self.decoder_blocks = nn.ModuleList([DecoderBlock(dim=decoder_dim, num_heads=decoder_num_heads,
                                                          rope=self.dec_pos_embed) for _ in range(decoder_depth)])

        # head_factory receives `self` and may read the attributes set above,
        # so keep them before these calls.
        self.pts_head = head_factory(pts_head_type, 'pts3d', self, has_conf=True)
        self.classifier_head = head_factory(classifier_head_type, 'pts3d', self, has_conf=True)

        self.imshape = imshape
        self.criterion = criterion
        self.return_all_blocks = return_all_blocks

        self.dino_layers = dino_layers
        # 1024 features per DINO layer — presumably the DINO backbone width; TODO confirm.
        # Created even when DINO is skipped (Linear(0, ...)), so state_dict keys stay stable.
        self.dino_proj = nn.Linear(1024 * len(dino_layers), decoder_dim)
        self.dino_pos_getter = PositionGetter()

        self.initialize_weights()

    @staticmethod
    def _rope_freq(spec):
        """Return the frequency encoded in a ``'RoPE<freq>'`` string.

        ``'RoPE100'`` gives 100.0, and a bare ``'RoPE'`` gives the historical
        default of 100.0. Any spec not starting with ``'RoPE'`` raises
        ``NotImplementedError``.
        """
        if not spec.startswith('RoPE'):
            raise NotImplementedError(f'{spec} not implemented')
        suffix = spec[len('RoPE'):]
        return float(suffix) if suffix else 100.0

    def initialize_weights(self):
        """Apply ``init_weights`` to every submodule.

        This includes the encoders, which were already initialised by their own
        constructors and are re-initialised here, as well as the heads.
        """
        self.apply(init_weights)

    def forward_encoders(self, rays, pointmaps, masks=None):
        """Encode query rays and context pointmaps.

        Args:
            rays: input to ``self.ray_enc``.
            pointmaps: 4-D input to ``self.pointmap_enc`` (channels-last).
            masks: optional validity mask forwarded to ``self.pointmap_enc``.

        Returns:
            ``(rays, rays_pos, pointmaps, pointmaps_pos)`` token/position pairs.

        Raises:
            ValueError: if ``pointmaps`` is not 4-D.
        """
        if len(pointmaps.shape) != 4:
            raise ValueError(f"expected 4-D pointmaps, got shape {tuple(pointmaps.shape)}")
        rays, rays_pos = self.ray_enc(rays)
        pointmaps, pointmaps_pos = self.pointmap_enc(pointmaps, masks=masks)
        return rays, rays_pos, pointmaps, pointmaps_pos

    def forward_decoder(self, rays, rays_pos, pointmaps, pointmaps_pos):
        """Run the decoder blocks, updating both ray and pointmap tokens.

        Returns:
            The list of ray tokens after each block when ``return_all_blocks``
            is set, otherwise the final ray tokens.
        """
        all_blocks = [] if self.return_all_blocks else None
        for blk in self.decoder_blocks:
            rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
            if all_blocks is not None:
                all_blocks.append(rays)
        return rays if all_blocks is None else all_blocks

    def get_dino_pos(self, dino_features):
        """Return token positions for the DINO grid.

        The grid is ``(imshape[0] // 14, imshape[1] // 14)``, with batch size
        ``dino_features.shape[0]``.
        """
        # 14 is hard-coded; presumably the DINO patch size — TODO confirm.
        dino_h = self.imshape[0] // 14
        dino_w = self.imshape[1] // 14
        return self.dino_pos_getter(dino_features.shape[0], dino_h, dino_w, dino_features.device)

    def forward(self, batch, mode='loss'):
        """Run the full model on ``batch`` and evaluate the criterion.

        Reads ``batch['input_cams']`` (``'pointmaps'``, ``'valid_masks'`` and,
        unless DINO is skipped, ``'dino_features'``) and uses
        ``batch['new_cams']`` as the ground truth.

        Args:
            batch: input dict (see above); also passed to ``compute_rays``.
            mode: ``'loss'`` returns ``loss_dict``; ``'viz'`` returns
                ``(pred_dict, gt_dict, loss_dict)``.

        Raises:
            ValueError: for any other ``mode``. Checked before any computation.
        """
        if mode not in ('loss', 'viz'):
            raise ValueError(f"Invalid mode: {mode}")

        rays = compute_rays(batch)
        input_cams = batch['input_cams']
        rays, rays_pos, pointmaps, pointmaps_pos = self.forward_encoders(
            rays, input_cams['pointmaps'], masks=input_cams['valid_masks'])

        if not self.skip_dino:
            dino_features = self.dino_proj(input_cams['dino_features'])
            if dino_features.dim() == 4:
                dino_features = dino_features.squeeze(1)
            dino_pos = self.get_dino_pos(dino_features)
            # Append the projected DINO tokens after the pointmap tokens (axis 1).
            pointmaps = torch.cat([pointmaps, dino_features], dim=1)
            pointmaps_pos = torch.cat([pointmaps_pos, dino_pos], dim=1)

        rays = self.forward_decoder(rays, rays_pos, pointmaps, pointmaps_pos)
        pred_dict = {**self.pts_head(rays, self.imshape),
                     **self.classifier_head(rays, self.imshape)}
        gt_dict = batch['new_cams']
        loss_dict = self.criterion(pred_dict, gt_dict)

        if mode == 'loss':
            return loss_dict
        return pred_dict, gt_dict, loss_dict