Spaces:
Build error
Build error
Ren Jiawei
commited on
Commit
·
ac0541e
1
Parent(s):
1c55e0d
update
Browse files
app.py
CHANGED
|
@@ -19,14 +19,14 @@ with open('shape_names.txt') as f:
|
|
| 19 |
|
| 20 |
model_gda = GDANET()
|
| 21 |
model_gda = nn.DataParallel(model_gda)
|
| 22 |
-
|
| 23 |
-
model_gda.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
|
| 24 |
model_gda.eval()
|
| 25 |
|
| 26 |
model_dgcnn = DGCNN()
|
| 27 |
model_dgcnn = nn.DataParallel(model_dgcnn)
|
| 28 |
-
|
| 29 |
-
model_dgcnn.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/dgcnn.t7', map_location=torch.device('cpu')))
|
| 30 |
model_dgcnn.eval()
|
| 31 |
|
| 32 |
def pyplot_draw_point_cloud(points, corruption):
|
|
@@ -68,11 +68,11 @@ def load_dataset(corruption_idx, severity):
|
|
| 68 |
]
|
| 69 |
corruption_type = corruptions[corruption_idx]
|
| 70 |
if corruption_type == 'clean':
|
| 71 |
-
|
| 72 |
-
f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '.h5'))
|
| 73 |
else:
|
| 74 |
-
|
| 75 |
-
f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '_{}'.format(severity - 1) + '.h5'))
|
| 76 |
data = f['data'][:].astype('float32')
|
| 77 |
label = f['label'][:].astype('int64')
|
| 78 |
f.close()
|
|
|
|
| 19 |
|
| 20 |
model_gda = GDANET()
|
| 21 |
model_gda = nn.DataParallel(model_gda)
|
| 22 |
+
model_gda.load_state_dict(torch.load('./GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
|
| 23 |
+
# model_gda.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
|
| 24 |
model_gda.eval()
|
| 25 |
|
| 26 |
model_dgcnn = DGCNN()
|
| 27 |
model_dgcnn = nn.DataParallel(model_dgcnn)
|
| 28 |
+
model_dgcnn.load_state_dict(torch.load('./dgcnn.t7', map_location=torch.device('cpu')))
|
| 29 |
+
# model_dgcnn.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/dgcnn.t7', map_location=torch.device('cpu')))
|
| 30 |
model_dgcnn.eval()
|
| 31 |
|
| 32 |
def pyplot_draw_point_cloud(points, corruption):
|
|
|
|
| 68 |
]
|
| 69 |
corruption_type = corruptions[corruption_idx]
|
| 70 |
if corruption_type == 'clean':
|
| 71 |
+
f = h5py.File(osp.join('modelnet_c', corruption_type + '.h5'))
|
| 72 |
+
# f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '.h5'))
|
| 73 |
else:
|
| 74 |
+
f = h5py.File(osp.join('modelnet_c', corruption_type + '_{}'.format(severity-1) + '.h5'))
|
| 75 |
+
# f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '_{}'.format(severity - 1) + '.h5'))
|
| 76 |
data = f['data'][:].astype('float32')
|
| 77 |
label = f['label'][:].astype('int64')
|
| 78 |
f.close()
|