Mariam-Elz commited on
Commit
4b0f5e0
·
verified ·
1 Parent(s): 9d1f8fa

Upload model/archs/decoders/shape_texture_net.py with huggingface_hub

Browse files
model/archs/decoders/shape_texture_net.py CHANGED
@@ -1,62 +1,62 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- class TetTexNet(nn.Module):
7
- def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
8
- super().__init__()
9
- # self.c_dim = c_dim
10
- self.plane_reso = plane_reso
11
- self.padding = padding
12
- self.fea_concat = fea_concat
13
-
14
- def forward(self, rolled_out_feature, query):
15
- # rolled_out_feature: rolled-out triplane feature
16
- # query: queried xyz coordinates (should be scaled consistently to ptr cloud)
17
-
18
- plane_reso = self.plane_reso
19
-
20
- triplane_feature = dict()
21
- triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
22
- triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
23
- triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]
24
-
25
- query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
26
- query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
27
- query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')
28
-
29
- if self.fea_concat:
30
- query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
31
- else:
32
- query_feature = query_feature_xy + query_feature_yz + query_feature_zx
33
-
34
- output = query_feature.permute(0, 2, 1)
35
-
36
- return output
37
-
38
- # uses values from plane_feature and pixel locations from vgrid to interpolate feature
39
- def sample_plane_feature(self, query, plane_feature, plane):
40
- # CYF note:
41
- # for pretraining, query are uniformly sampled positions w.i. [-scale, scale]
42
- # for training, query are essentially tetrahedra grid vertices, which are
43
- # also within [-scale, scale] in the current version!
44
- # xy range [-scale, scale]
45
- if plane == 'xy':
46
- xy = query[:, :, [0, 1]]
47
- elif plane == 'yz':
48
- xy = query[:, :, [1, 2]]
49
- elif plane == 'zx':
50
- xy = query[:, :, [2, 0]]
51
- else:
52
- raise ValueError("Error! Invalid plane type!")
53
-
54
- xy = xy[:, :, None].float()
55
- # not seem necessary to rescale the grid, because from
56
- # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
57
- # it specifies sampling locations normalized by plane_feature's spatial dimension,
58
- # which is within [-scale, scale] as specified by encoder's calling of coordinate2index()
59
- vgrid = 1.0 * xy
60
- sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
61
-
62
- return sampled_feat
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class TetTexNet(nn.Module):
7
+ def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
8
+ super().__init__()
9
+ # self.c_dim = c_dim
10
+ self.plane_reso = plane_reso
11
+ self.padding = padding
12
+ self.fea_concat = fea_concat
13
+
14
+ def forward(self, rolled_out_feature, query):
15
+ # rolled_out_feature: rolled-out triplane feature
16
+ # query: queried xyz coordinates (should be scaled consistently to ptr cloud)
17
+
18
+ plane_reso = self.plane_reso
19
+
20
+ triplane_feature = dict()
21
+ triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
22
+ triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
23
+ triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]
24
+
25
+ query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
26
+ query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
27
+ query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')
28
+
29
+ if self.fea_concat:
30
+ query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
31
+ else:
32
+ query_feature = query_feature_xy + query_feature_yz + query_feature_zx
33
+
34
+ output = query_feature.permute(0, 2, 1)
35
+
36
+ return output
37
+
38
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
39
+ def sample_plane_feature(self, query, plane_feature, plane):
40
+ # CYF note:
41
+ # for pretraining, query are uniformly sampled positions w.i. [-scale, scale]
42
+ # for training, query are essentially tetrahedra grid vertices, which are
43
+ # also within [-scale, scale] in the current version!
44
+ # xy range [-scale, scale]
45
+ if plane == 'xy':
46
+ xy = query[:, :, [0, 1]]
47
+ elif plane == 'yz':
48
+ xy = query[:, :, [1, 2]]
49
+ elif plane == 'zx':
50
+ xy = query[:, :, [2, 0]]
51
+ else:
52
+ raise ValueError("Error! Invalid plane type!")
53
+
54
+ xy = xy[:, :, None].float()
55
+ # not seem necessary to rescale the grid, because from
56
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
57
+ # it specifies sampling locations normalized by plane_feature's spatial dimension,
58
+ # which is within [-scale, scale] as specified by encoder's calling of coordinate2index()
59
+ vgrid = 1.0 * xy
60
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
61
+
62
+ return sampled_feat