Spaces:
Runtime error
Runtime error
Upload model/archs/decoders/shape_texture_net.py with huggingface_hub
Browse files
model/archs/decoders/shape_texture_net.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TetTexNet(nn.Module):
|
| 7 |
+
def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
|
| 8 |
+
super().__init__()
|
| 9 |
+
# self.c_dim = c_dim
|
| 10 |
+
self.plane_reso = plane_reso
|
| 11 |
+
self.padding = padding
|
| 12 |
+
self.fea_concat = fea_concat
|
| 13 |
+
|
| 14 |
+
def forward(self, rolled_out_feature, query):
|
| 15 |
+
# rolled_out_feature: rolled-out triplane feature
|
| 16 |
+
# query: queried xyz coordinates (should be scaled consistently to ptr cloud)
|
| 17 |
+
|
| 18 |
+
plane_reso = self.plane_reso
|
| 19 |
+
|
| 20 |
+
triplane_feature = dict()
|
| 21 |
+
triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
|
| 22 |
+
triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
|
| 23 |
+
triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]
|
| 24 |
+
|
| 25 |
+
query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
|
| 26 |
+
query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
|
| 27 |
+
query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')
|
| 28 |
+
|
| 29 |
+
if self.fea_concat:
|
| 30 |
+
query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
|
| 31 |
+
else:
|
| 32 |
+
query_feature = query_feature_xy + query_feature_yz + query_feature_zx
|
| 33 |
+
|
| 34 |
+
output = query_feature.permute(0, 2, 1)
|
| 35 |
+
|
| 36 |
+
return output
|
| 37 |
+
|
| 38 |
+
# uses values from plane_feature and pixel locations from vgrid to interpolate feature
|
| 39 |
+
def sample_plane_feature(self, query, plane_feature, plane):
|
| 40 |
+
# CYF note:
|
| 41 |
+
# for pretraining, query are uniformly sampled positions w.i. [-scale, scale]
|
| 42 |
+
# for training, query are essentially tetrahedra grid vertices, which are
|
| 43 |
+
# also within [-scale, scale] in the current version!
|
| 44 |
+
# xy range [-scale, scale]
|
| 45 |
+
if plane == 'xy':
|
| 46 |
+
xy = query[:, :, [0, 1]]
|
| 47 |
+
elif plane == 'yz':
|
| 48 |
+
xy = query[:, :, [1, 2]]
|
| 49 |
+
elif plane == 'zx':
|
| 50 |
+
xy = query[:, :, [2, 0]]
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError("Error! Invalid plane type!")
|
| 53 |
+
|
| 54 |
+
xy = xy[:, :, None].float()
|
| 55 |
+
# not seem necessary to rescale the grid, because from
|
| 56 |
+
# https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
|
| 57 |
+
# it specifies sampling locations normalized by plane_feature's spatial dimension,
|
| 58 |
+
# which is within [-scale, scale] as specified by encoder's calling of coordinate2index()
|
| 59 |
+
vgrid = 1.0 * xy
|
| 60 |
+
sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
|
| 61 |
+
|
| 62 |
+
return sampled_feat
|