Spaces:
Runtime error
Runtime error
lyndonzheng
commited on
Commit
•
13398c7
1
Parent(s):
3afeb70
add photos
Browse files- app.py +1 -1
- demo_examples/blenheim_palace.JPG +0 -0
- demo_examples/blenheim_palace_bedroom.png +0 -0
- demo_examples/blenheim_palace_living.png +0 -0
- demo_examples/christ_church_cathedral.png +0 -0
- demo_examples/radcliffe.png +0 -0
- flash3d/networks/gaussian_predictor.py +9 -7
- flash3d/util/vis3d.py +1 -6
app.py
CHANGED
@@ -93,7 +93,7 @@ def main():
|
|
93 |
gr.Markdown(
|
94 |
"""
|
95 |
# Flash3D
|
96 |
-
**Flash3D** [[project page](https://www.robots.ox.ac.uk/~vgg/research/flash3d/)] is a fast, super efficient, trinable on a single GPU in a day for
|
97 |
The model used in the demo was trained on only **RealEstate10k dataset on a single A6000 GPU within 1 day**.
|
98 |
Upload an image of a scene or click on one of the provided examples to see how the Flash3D does.
|
99 |
The 3D viewer will render a .ply scene exported from the 3D Gaussians, which is only an approximation.
|
|
|
93 |
gr.Markdown(
|
94 |
"""
|
95 |
# Flash3D
|
96 |
+
**Flash3D** [[project page](https://www.robots.ox.ac.uk/~vgg/research/flash3d/)] is a fast, super efficient, trinable on a single GPU in a day for scene 3D reconstruction from a single image.
|
97 |
The model used in the demo was trained on only **RealEstate10k dataset on a single A6000 GPU within 1 day**.
|
98 |
Upload an image of a scene or click on one of the provided examples to see how the Flash3D does.
|
99 |
The 3D viewer will render a .ply scene exported from the 3D Gaussians, which is only an approximation.
|
demo_examples/blenheim_palace.JPG
ADDED
demo_examples/blenheim_palace_bedroom.png
ADDED
demo_examples/blenheim_palace_living.png
ADDED
demo_examples/christ_church_cathedral.png
ADDED
demo_examples/radcliffe.png
ADDED
flash3d/networks/gaussian_predictor.py
CHANGED
@@ -73,23 +73,25 @@ class GaussianPredictor(nn.Module):
|
|
73 |
self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()
|
74 |
|
75 |
self.models = nn.ModuleDict(models)
|
|
|
76 |
|
|
|
77 |
backproject_depth = {}
|
78 |
-
H = cfg.dataset.height
|
79 |
-
W = cfg.dataset.width
|
80 |
-
for scale in cfg.model.scales:
|
81 |
h = H // (2 ** scale)
|
82 |
w = W // (2 ** scale)
|
83 |
-
if cfg.model.shift_rays_half_pixel == "zero":
|
84 |
shift_rays_half_pixel = 0
|
85 |
-
elif cfg.model.shift_rays_half_pixel == "forward":
|
86 |
shift_rays_half_pixel = 0.5
|
87 |
-
elif cfg.model.shift_rays_half_pixel == "backward":
|
88 |
shift_rays_half_pixel = -0.5
|
89 |
else:
|
90 |
raise NotImplementedError
|
91 |
backproject_depth[str(scale)] = BackprojectDepth(
|
92 |
-
cfg.optimiser.batch_size * cfg.model.gaussians_per_pixel,
|
93 |
# backprojection can be different if padding was used
|
94 |
h + 2 * self.cfg.dataset.pad_border_aug,
|
95 |
w + 2 * self.cfg.dataset.pad_border_aug,
|
|
|
73 |
self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()
|
74 |
|
75 |
self.models = nn.ModuleDict(models)
|
76 |
+
self.set_backproject()
|
77 |
|
78 |
+
def set_backproject(self):
|
79 |
backproject_depth = {}
|
80 |
+
H = self.cfg.dataset.height
|
81 |
+
W = self.cfg.dataset.width
|
82 |
+
for scale in self.cfg.model.scales:
|
83 |
h = H // (2 ** scale)
|
84 |
w = W // (2 ** scale)
|
85 |
+
if self.cfg.model.shift_rays_half_pixel == "zero":
|
86 |
shift_rays_half_pixel = 0
|
87 |
+
elif self.cfg.model.shift_rays_half_pixel == "forward":
|
88 |
shift_rays_half_pixel = 0.5
|
89 |
+
elif self.cfg.model.shift_rays_half_pixel == "backward":
|
90 |
shift_rays_half_pixel = -0.5
|
91 |
else:
|
92 |
raise NotImplementedError
|
93 |
backproject_depth[str(scale)] = BackprojectDepth(
|
94 |
+
self.cfg.optimiser.batch_size * self.cfg.model.gaussians_per_pixel,
|
95 |
# backprojection can be different if padding was used
|
96 |
h + 2 * self.cfg.dataset.pad_border_aug,
|
97 |
w + 2 * self.cfg.dataset.pad_border_aug,
|
flash3d/util/vis3d.py
CHANGED
@@ -107,11 +107,9 @@ def export_ply(
|
|
107 |
PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|
108 |
|
109 |
|
110 |
-
def save_ply(outputs, path, num_gauss=3):
|
111 |
-
pad = 32
|
112 |
|
113 |
def crop_r(t):
|
114 |
-
h, w = 256, 384
|
115 |
H = h + pad * 2
|
116 |
W = w + pad * 2
|
117 |
t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W)
|
@@ -120,14 +118,11 @@ def save_ply(outputs, path, num_gauss=3):
|
|
120 |
return t
|
121 |
|
122 |
def crop(t):
|
123 |
-
h, w = 256, 384
|
124 |
H = h + pad * 2
|
125 |
W = w + pad * 2
|
126 |
t = t[..., pad:H-pad, pad:W-pad]
|
127 |
return t
|
128 |
|
129 |
-
# import pdb
|
130 |
-
# pdb.set_trace()
|
131 |
means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3]
|
132 |
scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|
133 |
rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|
|
|
107 |
PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|
108 |
|
109 |
|
110 |
+
def save_ply(outputs, path, num_gauss=3, h=256, w=384, pad=32):
|
|
|
111 |
|
112 |
def crop_r(t):
|
|
|
113 |
H = h + pad * 2
|
114 |
W = w + pad * 2
|
115 |
t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W)
|
|
|
118 |
return t
|
119 |
|
120 |
def crop(t):
|
|
|
121 |
H = h + pad * 2
|
122 |
W = w + pad * 2
|
123 |
t = t[..., pad:H-pad, pad:W-pad]
|
124 |
return t
|
125 |
|
|
|
|
|
126 |
means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3]
|
127 |
scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|
128 |
rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|