rayst3r / models /heads /linear_head.py
bartduis's picture
init
70d1188
raw
history blame
1.48 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from .postprocess import postprocess
class LinearPts3d(nn.Module):
    """Linear prediction head (dust3r-style).

    A single linear layer maps every decoder token to a
    ``patch_size x patch_size`` block of per-pixel values, which is then
    rearranged into a dense map with ``F.pixel_shuffle``.

    The number of values predicted per pixel depends on ``mode``:

    * ``'pts3d'``      -> 3 channels (+1 if ``has_conf``)
    * ``'depth'``      -> 1 channel  (+1 if ``has_conf``)
    * ``'classifier'`` -> 1 channel  (+1 if ``has_conf``); additionally
      forwards ``net.classifier_mode`` to ``postprocess``.
    """

    # Per-pixel output channels for each supported mode, before the optional
    # confidence channel is added.
    _BASE_CHANNELS = {'pts3d': 3, 'depth': 1, 'classifier': 1}

    def __init__(self, net, has_conf=False, mode='pts3d'):
        """Build the projection layer from the settings carried by ``net``.

        Args:
            net: object exposing ``patch_size``, ``depth_mode``, ``conf_mode``
                and ``dec_embed_dim`` (plus ``classifier_mode`` when
                ``mode == 'classifier'``).
            has_conf: when true, one extra channel per pixel is predicted.
            mode: one of ``'pts3d'``, ``'depth'`` or ``'classifier'``.

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        super().__init__()
        # Fail at construction time: previously an unknown mode silently left
        # ``self.proj`` undefined and only crashed later inside ``forward``.
        if mode not in self._BASE_CHANNELS:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of "
                f"{sorted(self._BASE_CHANNELS)}"
            )

        self.patch_size = net.patch_size
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf
        self.mode = mode
        # Only the classifier mode reads a classifier setting from ``net``.
        self.classifier_mode = net.classifier_mode if mode == 'classifier' else None

        channels = self._BASE_CHANNELS[mode] + has_conf
        # Each token predicts ``channels`` values for every pixel of its patch.
        self.proj = nn.Linear(net.dec_embed_dim, channels * self.patch_size ** 2)

    def setup(self, croconet):
        """No-op hook; this head needs no extra setup."""
        pass

    def forward(self, decout, img_shape):
        """Project the last decoder output to a dense map and post-process it.

        Args:
            decout: indexable collection of decoder outputs; only the last
                entry is used and it must be a 3-D tensor ``(B, S, D)``.
                NOTE(review): ``D`` is presumably ``net.dec_embed_dim`` and
                ``S`` the number of patches ``(H // p) * (W // p)`` - confirm
                against the caller.
            img_shape: ``(H, W)`` of the image the tokens were computed from.

        Returns:
            Whatever ``postprocess`` returns for the dense feature map
            (defined in the sibling ``postprocess`` module).
        """
        H, W = img_shape
        tokens = decout[-1]
        # Unpacking doubles as a check that ``tokens`` is 3-dimensional.
        B, _, _ = tokens.shape

        # (B, S, D) -> (B, S, C * patch_size**2), C = per-pixel channel count
        feat = self.proj(tokens)
        # Move the token axis last and unfold it into the patch grid:
        # (B, C * patch_size**2, H // patch_size, W // patch_size)
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        # Spread each token's patch_size**2 values over its spatial patch:
        # (B, C, H, W) when H and W are multiples of patch_size.
        feat = F.pixel_shuffle(feat, self.patch_size)

        # Remaining output handling is delegated to ``postprocess``.
        return postprocess(feat, self.depth_mode, self.conf_mode, self.classifier_mode)