Spaces: Running on Zero

Commit 84eee5b: first commit
Paul Engstler committed · Parent(s): none

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +35 -0
- .gitignore +148 -0
- README.md +10 -0
- app.py +257 -0
- examples/photo-1469559845082-95b66baaf023.jpeg +0 -0
- examples/photo-1499916078039-922301b0eb9b.jpeg +0 -0
- examples/photo-1514984879728-be0aff75a6e8.jpeg +0 -0
- examples/photo-1546975490-e8b92a360b24.jpeg +0 -0
- examples/photo-1618197345638-d2df92b39fe1.jpeg +0 -0
- examples/photo-1628624747186-a941c476b7ef.jpeg +0 -0
- examples/photo-1667788000333-4e36f948de9a.jpeg +0 -0
- packages.txt +1 -0
- pre-requirements.txt +0 -0
- requirements.txt +26 -0
- utils/demo.py +54 -0
- utils/gaussian_renderer/__init__.py +100 -0
- utils/gaussian_renderer/network_gui.py +86 -0
- utils/gs.py +196 -0
- utils/models.py +119 -0
- utils/ops.py +95 -0
- utils/render.py +112 -0
- utils/scene/__init__.py +92 -0
- utils/scene/cameras.py +76 -0
- utils/scene/colmap_loader.py +294 -0
- utils/scene/dataset_readers.py +270 -0
- utils/scene/gaussian_model.py +416 -0
- utils/scene/utils/camera_utils.py +84 -0
- utils/scene/utils/general_utils.py +133 -0
- utils/scene/utils/graphics_utils.py +88 -0
- utils/scene/utils/image_utils.py +19 -0
- utils/scene/utils/loss_utils.py +65 -0
- utils/scene/utils/sh_utils.py +118 -0
- utils/scene/utils/system_utils.py +28 -0
- zoedepth/LICENSE +21 -0
- zoedepth/data/__init__.py +24 -0
- zoedepth/data/data_mono.py +697 -0
- zoedepth/data/ddad.py +117 -0
- zoedepth/data/diml_indoor_test.py +125 -0
- zoedepth/data/diml_outdoor_test.py +114 -0
- zoedepth/data/diode.py +125 -0
- zoedepth/data/hypersim.py +138 -0
- zoedepth/data/ibims.py +81 -0
- zoedepth/data/marigold_nyu.py +112 -0
- zoedepth/data/places365.py +118 -0
- zoedepth/data/preprocess.py +154 -0
- zoedepth/data/sun_rgbd_loader.py +106 -0
- zoedepth/data/transforms.py +481 -0
- zoedepth/data/vkitti.py +151 -0
- zoedepth/data/vkitti2.py +187 -0
- zoedepth/models/__init__.py +24 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,148 @@
*.png
**.gif
.vscode/
*.rdb
**.xml
wandb/
slurm/
tmp/
.logs/
checkpoints/
external_jobs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
ptlflow_logs/
output/
log/
.idea/
# C extensions
*.so
results/
**.DS_Store
**.pt
demo/
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
~shortcuts/
**/wandb_logs/
**.db
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
README.md
ADDED
@@ -0,0 +1,10 @@
---
title: Invisible Stitch
emoji: 🪡
colorFrom: pink
colorTo: purple
sdk: gradio
sdk_version: 4.27.0
app_file: app.py
pinned: false
---
app.py
ADDED
@@ -0,0 +1,257 @@
import spaces
import os

# this is a HF Spaces specific hack, as
# (i) building pytorch3d with GPU support is a bit tricky here
# (ii) installing the wheel via requirements.txt breaks ZeroGPU
os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

import skimage
from PIL import Image

import gradio as gr

from utils.render import PointsRendererWithMasks, render
from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation
from utils.gs import gs_options, read_cameras_from_optimization_bundle, Scene, run_gaussian_splatting, get_blank_gs_bundle

from pytorch3d.utils import opencv_from_cameras_projection
from utils.ops import focal2fov, fov2focal
from utils.models import infer_with_zoe_dc
from utils.scene import GaussianModel
from utils.demo import downsample_point_cloud
from typing import Iterable, Tuple, Dict, Optional
import itertools

from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)

from pytorch3d.io import IO

def get_blank_gs_bundle(h, w):
    return {
        "camera_angle_x": focal2fov(torch.tensor([w], dtype=torch.float32), w),
        "W": w,
        "H": h,
        "pcd_points": None,
        "pcd_colors": None,
        'frames': [],
    }

@spaces.GPU(duration=30)
def extrapolate_point_cloud(prompt: str, image_size: Tuple[int, int], look_at_params: Iterable[Tuple[float, float, float, Tuple[float, float, float]]], point_cloud: Pointclouds = None, dry_run: bool = False, discard_mask: bool = False, initial_image: Optional[Image.Image] = None, depth_scaling: float = 1, **render_kwargs):
    w, h = image_size
    optimization_bundle_frames = []

    for azim, elev, dist, at in look_at_params:
        R, T = look_at_view_transform(device=device, azim=azim, elev=elev, dist=dist, at=at)
        cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32), principal_point=(((h-1)/2, (w-1)/2),), image_size=(image_size,), device=device, in_ndc=False)

        if point_cloud is not None:
            images, masks, depths = render(cameras, point_cloud, **render_kwargs)

            if not dry_run:
                eroded_mask = skimage.morphology.binary_erosion((depths[0] > 0).cpu().numpy(), footprint=None)#skimage.morphology.disk(1))
                eroded_depth = depths[0].clone()
                eroded_depth[torch.from_numpy(eroded_mask).to(depths.device) <= 0] = 0

                outpainted_img, aligned_depth = outpaint_with_depth_estimation(images[0], masks[0], eroded_depth, h, w, pipe, zoe_dc_model, prompt, cameras, dilation_size=2, depth_scaling=depth_scaling, generator=torch.Generator(device=pipe.device).manual_seed(0))

                aligned_depth = torch.from_numpy(aligned_depth).to(device)

            else:
                # in a dry run, we do not actually outpaint the image
                outpainted_img = Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8))

        else:
            assert initial_image is not None
            assert not dry_run

            # jumpstart the point cloud with a regular depth estimation
            t_initial_image = torch.from_numpy(np.asarray(initial_image)/255.).permute(2,0,1).float()
            depth = aligned_depth = infer_with_zoe_dc(zoe_dc_model, t_initial_image, torch.zeros(h, w))
            outpainted_img = initial_image
            images = [t_initial_image.to(device)]
            masks = [torch.ones(h, w, dtype=torch.bool).to(device)]

        if not dry_run:
            # snap high gradients to nearest neighbor, which eliminates noodle artifacts
            aligned_depth = snap_high_gradients_to_nn(aligned_depth, threshold=12).cpu()
            xy_depth_world = project_points(cameras, aligned_depth)

        c2w = cameras.get_world_to_view_transform().get_matrix()[0]

        optimization_bundle_frames.append({
            "image": outpainted_img,
            "mask": masks[0].cpu().numpy(),
            "transform_matrix": c2w.tolist(),
            "azim": azim,
            "elev": elev,
            "dist": dist,
        })

        if discard_mask:
            optimization_bundle_frames[-1].pop("mask")

        if not dry_run:
            optimization_bundle_frames[-1]["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
            optimization_bundle_frames[-1]["depth"] = aligned_depth.cpu().numpy()
            optimization_bundle_frames[-1]["mean_depth"] = aligned_depth.mean().item()

        else:
            # in a dry run, we do not modify the point cloud
            continue

        rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)

        if point_cloud is None:
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)

        else:
            # pytorch 3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams
            # in theory, 1 pixel is sufficient but we use 2 to be safe
            masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)

            partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])

            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return optimization_bundle_frames, point_cloud

@spaces.GPU(duration=30)
def generate_point_cloud(initial_image: Image.Image, prompt: str):
    image_size = initial_image.size
    w, h = image_size

    optimization_bundle = get_blank_gs_bundle(h, w)

    step_size = 25

    azim_steps = [0, step_size, -step_size]
    look_at_params = [(azim, 0, 0.01, torch.zeros((1, 3))) for azim in azim_steps]

    optimization_bundle["frames"], point_cloud = extrapolate_point_cloud(prompt, image_size, look_at_params, discard_mask=True, initial_image=initial_image, depth_scaling=0.5, fill_point_cloud_holes=True)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud

@spaces.GPU(duration=30)
def supplement_point_cloud(optimization_bundle: Dict, point_cloud: Pointclouds, prompt: str):
    w, h = optimization_bundle["W"], optimization_bundle["H"]

    supporting_frames = []

    for i, frame in enumerate(optimization_bundle["frames"]):
        # skip supporting views
        if frame.get("supporting", False):
            continue

        center_point = torch.tensor(frame["center_point"]).to(device)
        mean_depth = frame["mean_depth"]
        azim, elev = frame["azim"], frame["elev"]

        azim_jitters = torch.linspace(-5, 5, 3).tolist()
        elev_jitters = torch.linspace(-5, 5, 3).tolist()

        # build the product of azim and elev jitters
        camera_jitters = [{"azim": azim + azim_jitter, "elev": elev + elev_jitter} for azim_jitter, elev_jitter in itertools.product(azim_jitters, elev_jitters)]

        look_at_params = [(camera_jitter["azim"], camera_jitter["elev"], mean_depth, center_point.unsqueeze(0)) for camera_jitter in camera_jitters]

        local_supporting_frames, point_cloud = extrapolate_point_cloud(prompt, (w, h), look_at_params, point_cloud, dry_run=True, depth_scaling=0.5, antialiasing=3)

        for local_supporting_frame in local_supporting_frames:
            local_supporting_frame["supporting"] = True

        supporting_frames.extend(local_supporting_frames)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud

@spaces.GPU(duration=30)
def generate_scene(img: Image.Image, prompt: str):
    assert isinstance(img, Image.Image)

    # resize image maintaining the aspect ratio so the longest side is 720 pixels
    max_size = 720
    img.thumbnail((max_size, max_size))

    # crop to ensure the image dimensions are divisible by 8
    img = img.crop((0, 0, img.width - img.width % 8, img.height - img.height % 8))

    from hashlib import sha1
    from datetime import datetime

    run_id = sha1(datetime.now().isoformat().encode()).hexdigest()[:6]

    run_name = f"gradio_{run_id}"

    gs_optimization_bundle, point_cloud = generate_point_cloud(img, prompt)

    downsampled_point_cloud = downsample_point_cloud(gs_optimization_bundle, device=device)

    gs_optimization_bundle["pcd_points"] = downsampled_point_cloud.points_padded()[0].cpu().numpy()
    gs_optimization_bundle["pcd_colors"] = downsampled_point_cloud.features_padded()[0].cpu().numpy()

    scene = Scene(gs_optimization_bundle, GaussianModel(gs_options.sh_degree), gs_options)

    scene.gaussians._opacity = torch.ones_like(scene.gaussians._opacity)
    #scene = run_gaussian_splatting(scene, gs_optimization_bundle)

    # coordinate system transformation
    scene.gaussians._xyz = scene.gaussians._xyz.detach()
    scene.gaussians._xyz[:, 1] = -scene.gaussians._xyz[:, 1]
    scene.gaussians._xyz[:, 2] = -scene.gaussians._xyz[:, 2]

    save_path = os.path.join("outputs", f"{run_name}.ply")

    scene.gaussians.save_ply(save_path)

    return save_path

if __name__ == "__main__":
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from utils.models import get_zoe_dc_model, get_sd_pipeline

    global zoe_dc_model
    from huggingface_hub import hf_hub_download
    zoe_dc_model = get_zoe_dc_model(ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")).to(device)

    global pipe
    pipe = get_sd_pipeline().to(device)

    demo = gr.Interface(
        fn=generate_scene,
        inputs=[
            gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil"),
            gr.Textbox(label="Scene Hallucination Prompt")
        ],
        outputs=gr.Model3D(label="Generated Scene"),
        allow_flagging="never",
        title="Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting",
        description="Hallucinate geometrically coherent 3D scenes from a single input image in less than 30 seconds.<br /> [Project Page](https://research.paulengstler.com/invisible-stitch) | [GitHub](https://github.com/paulengstler/invisible-stitch) | [Paper](#) <br /><br />To keep this demo snappy, we have limited its functionality. Scenes are generated at a low resolution without densification, supporting views are not inpainted, and we do not optimize the resulting point cloud. Imperfections are to be expected, in particular around object borders. Please allow a couple of seconds for the generated scene to be downloaded (about 40 megabytes).",
        article="Please consider running this demo locally to obtain high-quality results (see the GitHub repository).<br /><br />Here are some observations we made that might help you to get better results:<ul><li>Use generic prompts that match the surroundings of your input image.</li><li>Ensure that the borders of your input image are free from partially visible objects.</li><li>Keep your prompts simple and avoid adding specific details.</li></ul>",
        examples=[
            ["examples/photo-1667788000333-4e36f948de9a.jpeg", "a street with traditional buildings in Kyoto, Japan"],
            ["examples/photo-1628624747186-a941c476b7ef.jpeg", "a suburban street in North Carolina on a bright, sunny day"],
            ["examples/photo-1469559845082-95b66baaf023.jpeg", "a view of Zion National Park"],
            ["examples/photo-1514984879728-be0aff75a6e8.jpeg", "a close-up view of a muddy path in a forest"],
            ["examples/photo-1618197345638-d2df92b39fe1.jpeg", "a close-up view of a white linen bed in a minimalistic room"],
            ["examples/photo-1546975490-e8b92a360b24.jpeg", "a warm living room with plants"],
            ["examples/photo-1499916078039-922301b0eb9b.jpeg", "a cozy bedroom on a bright day"],
        ])
    demo.queue().launch(share=True)
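Note (not part of the commit): a minimal sketch of driving the same pipeline locally without the Gradio front end. It assumes app.py is importable as `app` (importing it still runs the pytorch3d pip install at the top), that the `spaces` decorators are no-ops outside a Space, and that a CUDA device and the hosted checkpoint are available; the globals are set exactly as in the `__main__` block above.

# Hypothetical local driver for the pipeline above (a sketch, not the committed entry point).
import os
import torch
from PIL import Image
from huggingface_hub import hf_hub_download

import app  # assumption: importing app.py outside Spaces works; the __main__ block will not run
from utils.models import get_zoe_dc_model, get_sd_pipeline

# replicate the globals that app.generate_scene expects
app.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
app.zoe_dc_model = get_zoe_dc_model(
    ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch",
                              filename="invisible-stitch.pt")
).to(app.device)
app.pipe = get_sd_pipeline().to(app.device)

os.makedirs("outputs", exist_ok=True)  # generate_scene writes outputs/gradio_<id>.ply

# single image -> point cloud -> Gaussian scene, same path the demo button takes
ply_path = app.generate_scene(
    Image.open("examples/photo-1546975490-e8b92a360b24.jpeg"),
    "a warm living room with plants",
)
print("Gaussian splat written to", ply_path)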
examples/photo-1469559845082-95b66baaf023.jpeg
ADDED
examples/photo-1499916078039-922301b0eb9b.jpeg
ADDED
examples/photo-1514984879728-be0aff75a6e8.jpeg
ADDED
examples/photo-1546975490-e8b92a360b24.jpeg
ADDED
examples/photo-1618197345638-d2df92b39fe1.jpeg
ADDED
examples/photo-1628624747186-a941c476b7ef.jpeg
ADDED
examples/photo-1667788000333-4e36f948de9a.jpeg
ADDED
packages.txt
ADDED
@@ -0,0 +1 @@
python3-dev
pre-requirements.txt
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,26 @@
datasets==2.19.0
diffusers==0.26.3
fire==0.5.0
gradio==4.27.0
h5py==3.10.0
huggingface_hub==0.22.2
imageio==2.33.1
jaxtyping==0.2.28
matplotlib==3.7.5
numpy==1.22.4
opencv_python==4.8.0.76
pandas==1.5.1
Pillow==10.3.0
plyfile==1.0.3
scipy==1.8.1
scikit-image
submitit==1.5.1
tqdm==4.66.1
trimesh==3.21.7
wandb==0.16.3
xformers==0.0.25
spaces
timm==0.6.7
transformers==4.37.2
accelerate==0.27.2
easydict
utils/demo.py
ADDED
@@ -0,0 +1,54 @@
import copy
import torch
import numpy as np

import skimage
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)

from .render import render
from .ops import project_points, get_pointcloud, merge_pointclouds

def downsample_point_cloud(optimization_bundle, device="cpu"):
    point_cloud = None

    for i, frame in enumerate(optimization_bundle["frames"]):
        if frame.get("supporting", False):
            continue

        downsampled_image = copy.deepcopy(frame["image"])
        downsampled_image.thumbnail((360, 360))

        image_size = downsampled_image.size
        w, h = image_size

        # regenerate the point cloud at a lower resolution
        R, T = look_at_view_transform(device=device, azim=frame["azim"], elev=frame["elev"], dist=frame["dist"])#, dist=1+0.15*step)
        cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32), principal_point=(((h-1)/2, (w-1)/2),), image_size=(image_size,), device=device, in_ndc=False)

        # downsample the depth
        downsampled_depth = torch.nn.functional.interpolate(torch.tensor(frame["depth"]).unsqueeze(0).unsqueeze(0).float().to(device), size=(h, w), mode="nearest").squeeze()

        xy_depth_world = project_points(cameras, downsampled_depth)

        rgb = (torch.from_numpy(np.asarray(downsampled_image).copy()).reshape(-1, 3).float() / 255).to(device)

        c2w = cameras.get_world_to_view_transform().get_matrix()[0]

        if i == 0:
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)

        else:
            images, masks, depths = render(cameras, point_cloud, radius=1e-2)

            # pytorch 3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams
            # in theory, 1 pixel is sufficient but we use 2 to be safe
            masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(1))).to(device)

            partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])

            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return point_cloud
utils/gaussian_renderer/__init__.py
ADDED
@@ -0,0 +1,100 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!
    """

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except:
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.convert_SHs_python:
            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_features
    else:
        colors_precomp = override_color

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    rendered_image, radii = rasterizer(
        means3D = means3D,
        means2D = means2D,
        shs = shs,
        colors_precomp = colors_precomp,
        opacities = opacity,
        scales = scales,
        rotations = rotations,
        cov3D_precomp = cov3D_precomp)

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {"render": rendered_image,
            "viewspace_points": screenspace_points,
            "visibility_filter" : radii > 0,
            "radii": radii}
utils/gaussian_renderer/network_gui.py
ADDED
@@ -0,0 +1,86 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import traceback
import socket
import json
from scene.cameras import MiniCam

host = "127.0.0.1"
port = 6009

conn = None
addr = None

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def init(wish_host, wish_port):
    global host, port, listener
    host = wish_host
    port = wish_port
    listener.bind((host, port))
    listener.listen()
    listener.settimeout(0)

def try_connect():
    global conn, addr, listener
    try:
        conn, addr = listener.accept()
        print(f"\nConnected by {addr}")
        conn.settimeout(None)
    except Exception as inst:
        pass

def read():
    global conn
    messageLength = conn.recv(4)
    messageLength = int.from_bytes(messageLength, 'little')
    message = conn.recv(messageLength)
    return json.loads(message.decode("utf-8"))

def send(message_bytes, verify):
    global conn
    if message_bytes != None:
        conn.sendall(message_bytes)
    conn.sendall(len(verify).to_bytes(4, 'little'))
    conn.sendall(bytes(verify, 'ascii'))

def receive():
    message = read()

    width = message["resolution_x"]
    height = message["resolution_y"]

    if width != 0 and height != 0:
        try:
            do_training = bool(message["train"])
            fovy = message["fov_y"]
            fovx = message["fov_x"]
            znear = message["z_near"]
            zfar = message["z_far"]
            do_shs_python = bool(message["shs_python"])
            do_rot_scale_python = bool(message["rot_scale_python"])
            keep_alive = bool(message["keep_alive"])
            scaling_modifier = message["scaling_modifier"]
            world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
            world_view_transform[:,1] = -world_view_transform[:,1]
            world_view_transform[:,2] = -world_view_transform[:,2]
            full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
            full_proj_transform[:,1] = -full_proj_transform[:,1]
            custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
        except Exception as e:
            print("")
            traceback.print_exc()
            raise e
        return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
    else:
        return None, None, None, None, None, None
utils/gs.py
ADDED
@@ -0,0 +1,196 @@
import random
import torch
import numpy as np
from .scene import GaussianModel
from .scene.dataset_readers import SceneInfo, getNerfppNorm
from .scene.cameras import Camera
from .ops import focal2fov, fov2focal
from .scene.gaussian_model import BasicPointCloud
from easydict import EasyDict as edict
from PIL import Image

from tqdm.auto import tqdm

def get_blank_gs_bundle(h, w):
    return {
        "camera_angle_x": focal2fov(torch.tensor([w], dtype=torch.float32), w),
        "W": w,
        "H": h,
        "pcd_points": None,
        "pcd_colors": None,
        'frames': [],
    }

def read_cameras_from_optimization_bundle(optimization_bundle, white_background: bool = False):
    cameras = []

    fovx = optimization_bundle["camera_angle_x"]
    frames = optimization_bundle["frames"]

    # we flip the x and y axis to move from PyTorch3D's coordinate system to COLMAP's
    coordinate_system_transform = np.array([-1, -1, 1])

    for idx, frame in enumerate(frames):
        c2w = np.array(frame["transform_matrix"])
        c2w[:3, :3] = c2w[:3, :3] * coordinate_system_transform

        # get the world-to-camera transform and set R, T
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = c2w[-1, :3] * coordinate_system_transform

        image = frame["image"]

        im_data = np.array(image.convert("RGBA"))

        bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])

        norm_data = im_data / 255.0
        arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
        image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")

        fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
        FovY = fovy
        FovX = fovx

        image = torch.Tensor(arr).permute(2,0,1)

        cameras.append(Camera(colmap_id=idx, R=R, T=T, FoVx=FovX, FoVy=FovY, image=image, mask=frame.get("mask", None),
                              gt_alpha_mask=None, image_name='', uid=idx, data_device='cuda'))

    return cameras

class Scene:
    gaussians: GaussianModel

    def __init__(self, traindata, gaussians: GaussianModel, gs_options, shuffle: bool = True):
        self.traindata = traindata
        self.gaussians = gaussians

        train_cameras = read_cameras_from_optimization_bundle(traindata, gs_options.white_background)

        nerf_normalization = getNerfppNorm(train_cameras)

        pcd = BasicPointCloud(points=traindata['pcd_points'], colors=traindata['pcd_colors'], normals=None)

        scene_info = SceneInfo(point_cloud=pcd,
                               train_cameras=train_cameras,
                               test_cameras=[],
                               nerf_normalization=nerf_normalization,
                               ply_path='')

        if shuffle:
            random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        self.train_cameras = scene_info.train_cameras

        bg_color = np.array([1,1,1]) if gs_options.white_background else np.array([0, 0, 0])
        self.background = torch.tensor(bg_color, dtype=torch.float32, device='cuda')

        self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
        self.gaussians.training_setup(gs_options)

    def getTrainCameras(self):
        return self.train_cameras

    def getPresetCameras(self, preset):
        assert preset in self.preset_cameras
        return self.preset_cameras[preset]

def run_gaussian_splatting(scene, gs_optimization_bundle):
    torch.cuda.empty_cache()

    return scene

    from random import randint
    from .gaussian_renderer import render as gs_render
    from .scene.utils.loss_utils import l1_loss, ssim

    pbar = tqdm(range(1, gs_options.iterations + 1))
    for iteration in pbar:
        scene.gaussians.update_learning_rate(iteration)

        # Every 1000 its we increase the levels of SH up to a maximum degree
        if iteration % 1000 == 0:
            scene.gaussians.oneupSHdegree()

        # Pick a random Camera
        random_idx = randint(0, len(gs_optimization_bundle["frames"])-1)
        viewpoint_cam = scene.getTrainCameras()[random_idx]

        # Render
        render_pkg = gs_render(viewpoint_cam, scene.gaussians, gs_options, scene.background)
        image, viewspace_point_tensor, visibility_filter, radii = (
            render_pkg['render'], render_pkg['viewspace_points'], render_pkg['visibility_filter'], render_pkg['radii'])

        # Loss
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image, reduce=False)
        loss = (1.0 - gs_options.lambda_dssim) * Ll1

        if viewpoint_cam.mask is not None:
            mask = torch.from_numpy(viewpoint_cam.mask).to(loss.device)
        else:
            mask = 1

        loss = (loss * mask).mean()
        loss = loss + gs_options.lambda_dssim * (1.0 - ssim(image, gt_image))
        loss.backward()

        pbar.set_description(f"Loss: {loss.item():.4f}")

        with torch.no_grad():
            # Densification
            if iteration < gs_options.densify_until_iter:
                # Keep track of max radii in image-space for pruning
                scene.gaussians.max_radii2D[visibility_filter] = torch.max(
                    scene.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                scene.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if iteration > gs_options.densify_from_iter and iteration % gs_options.densification_interval == 0:
                    size_threshold = 20 if iteration > gs_options.opacity_reset_interval else None
                    scene.gaussians.densify_and_prune(
                        gs_options.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

                if (iteration % gs_options.opacity_reset_interval == 0
                    or (gs_options.white_background and iteration == gs_options.densify_from_iter)
                ):
                    scene.gaussians.reset_opacity()

            # Optimizer step
            if iteration < gs_options.iterations:
                scene.gaussians.optimizer.step()
                scene.gaussians.optimizer.zero_grad(set_to_none = True)

    return scene

gs_options = edict({
    "sh_degree": 3,
    "images": "images",
    "resolution": -1,
    "white_background": False,
    "data_device": "cuda",
    "eval": False,
    "use_depth": False,
    "iterations": 0,#250,
    "position_lr_init": 0.00016,
    "position_lr_final": 0.0000016,
    "position_lr_delay_mult": 0.01,
    "position_lr_max_steps": 2990,
    "feature_lr": 0.0,#0.0025,
    "opacity_lr": 0.0,#0.05,
    "scaling_lr": 0.0,#0.005,
    "rotation_lr": 0.0,#0.001,
    "percent_dense": 0.01,
    "lambda_dssim": 0.2,
    "densification_interval": 100,
    "opacity_reset_interval": 3000,
    "densify_from_iter": 10_000,
    "densify_until_iter": 15_000,
    "densify_grad_threshold": 0.0002,
    "convert_SHs_python": False,
    "compute_cov3D_python": False,
    "debug": False,
})
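Note (not part of the commit): the Space ships with the Gaussian optimization switched off, via "iterations": 0, zeroed learning rates next to their commented original values, an early return at the top of run_gaussian_splatting, and the commented-out call in app.py. A minimal sketch of what re-enabling it locally might look like; the numbers simply restore the values left in the comments, and `scene` / `gs_optimization_bundle` are assumed to come from app.generate_scene.

# Sketch (not part of the commit): re-enable the optimization the demo disables.
from utils.gs import gs_options, run_gaussian_splatting

# set these before constructing the Scene, since training_setup reads them
gs_options.iterations = 250      # was "iterations": 0,#250
gs_options.feature_lr = 0.0025   # was 0.0,#0.0025
gs_options.opacity_lr = 0.05     # was 0.0,#0.05
gs_options.scaling_lr = 0.005    # was 0.0,#0.005
gs_options.rotation_lr = 0.001   # was 0.0,#0.001

# the early `return scene` at the top of run_gaussian_splatting would also have
# to be removed for its training loop to execute; with it in place the call is a no-op.
scene = run_gaussian_splatting(scene, gs_optimization_bundle)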
utils/models.py
ADDED
@@ -0,0 +1,119 @@
import glob
import os

import torch
import torch.nn.functional as F
import numpy as np

from zoedepth.utils.misc import colorize
from zoedepth.utils.config import get_config
from zoedepth.models.builder import build_model
from zoedepth.models.model_io import load_wts

from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline

def load_ckpt(config, model, checkpoint_dir: str = "./checkpoints", ckpt_type: str = "best"):
    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        matches = glob.glob(os.path.join(
            checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
        if not (len(matches) > 0):
            raise ValueError(f"No matches found for the pattern {pattern}")

        checkpoint = matches[0]

    else:
        return model
    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model

def get_zoe_dc_model(vanilla: bool = False, ckpt_path: str = None, **kwargs):
    def ZoeD_N(midas_model_type="DPT_BEiT_L_384", vanilla=False, **kwargs):
        if midas_model_type != "DPT_BEiT_L_384":
            raise ValueError(f"Only DPT_BEiT_L_384 MiDaS model is supported for pretrained Zoe_N model, got: {midas_model_type}")

        zoedepth_config = get_config("zoedepth", "train", **kwargs)
        model = build_model(zoedepth_config)

        if vanilla:
            model.__setattr__("vanilla", True)
            return model
        else:
            model.__setattr__("vanilla", False)

        if zoedepth_config.add_depth_channel and not vanilla:
            model.core.core.pretrained.model.patch_embed.proj = torch.nn.Conv2d(
                model.core.core.pretrained.model.patch_embed.proj.in_channels+2,
                model.core.core.pretrained.model.patch_embed.proj.out_channels,
                kernel_size=model.core.core.pretrained.model.patch_embed.proj.kernel_size,
                stride=model.core.core.pretrained.model.patch_embed.proj.stride,
                padding=model.core.core.pretrained.model.patch_embed.proj.padding,
                bias=True)

        if ckpt_path is not None:
            assert os.path.exists(ckpt_path)
            zoedepth_config.__setattr__("checkpoint", ckpt_path)
        else:
            assert vanilla, "ckpt_path must be provided for non-vanilla model"

        model = load_ckpt(zoedepth_config, model)

        return model

    return ZoeD_N(vanilla=vanilla, ckpt_path=ckpt_path, **kwargs)

def infer_with_pad(zoe, x, pad_input: bool = True, fh: float = 3, fw: float = 3, upsampling_mode: str = "bicubic", padding_mode: str = "reflect", **kwargs):
    assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())

    if pad_input:
        assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0"
        pad_h = int(np.sqrt(x.shape[2]/2) * fh)
        pad_w = int(np.sqrt(x.shape[3]/2) * fw)
        padding = [pad_w, pad_w]
        if pad_h > 0:
            padding += [pad_h, pad_h]

        x_rgb = x[:, :3]
        x_remaining = x[:, 3:]
        x_rgb = F.pad(x_rgb, padding, mode=padding_mode, **kwargs)
        x_remaining = F.pad(x_remaining, padding, mode="constant", value=0, **kwargs)
        x = torch.cat([x_rgb, x_remaining], dim=1)
    out = zoe(x)["metric_depth"]
    if out.shape[-2:] != x.shape[-2:]:
        out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
    if pad_input:
        # crop to the original size, handling the case where pad_h and pad_w is 0
        if pad_h > 0:
            out = out[:, :, pad_h:-pad_h,:]
        if pad_w > 0:
            out = out[:, :, :, pad_w:-pad_w]
    return out

@torch.no_grad()
def infer_with_zoe_dc(zoe_dc, image, sparse_depth, scaling: float = 1):
    sparse_depth_mask = (sparse_depth[None, None, ...] > 0).float()
    # the metric depth range defined during training is [1e-3, 10]
    x = torch.cat([image[None, ...], sparse_depth[None, None, ...] / (float(scaling) * 10.0), sparse_depth_mask], dim=1).to(zoe_dc.device)

    out = infer_with_pad(zoe_dc, x)
    out_flip = infer_with_pad(zoe_dc, torch.flip(x, dims=[3]))
    out = (out + torch.flip(out_flip, dims=[3])) / 2

    pred_depth = float(scaling) * out

    return torch.nn.functional.interpolate(pred_depth, image.shape[-2:], mode='bilinear', align_corners=True)[0, 0]

def get_sd_pipeline():
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    )
    pipe.vae = AsymmetricAutoencoderKL.from_pretrained(
        "cross-attention/asymmetric-autoencoder-kl-x-2",
        torch_dtype=torch.float16
    )

    return pipe
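Note (not part of the commit): a minimal sketch of running the depth-completion model on its own, mirroring how app.py jumpstarts the first frame. The example image path is one of the committed examples; the checkpoint name matches the one app.py downloads. A sparse-depth map of zeros means "no hints", i.e. plain monocular depth estimation; nonzero entries would be treated as known metric depth to complete around.

# Sketch (not part of the commit): standalone depth completion with the ZoeDepth-based model above.
import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download

from utils.models import get_zoe_dc_model, infer_with_zoe_dc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
zoe_dc = get_zoe_dc_model(
    ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch",
                              filename="invisible-stitch.pt")
).to(device)

img = Image.open("examples/photo-1499916078039-922301b0eb9b.jpeg")
img.thumbnail((720, 720))                                            # same preprocessing as generate_scene
img = img.crop((0, 0, img.width - img.width % 8, img.height - img.height % 8))

rgb = torch.from_numpy(np.asarray(img) / 255.).permute(2, 0, 1).float()  # (3, H, W) in [0, 1]
sparse_depth = torch.zeros(rgb.shape[1], rgb.shape[2])                   # (H, W), all zeros = no depth hints

depth = infer_with_zoe_dc(zoe_dc, rgb, sparse_depth)  # (H, W) metric depth
print(depth.shape, float(depth.min()), float(depth.max()))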
utils/ops.py
ADDED
@@ -0,0 +1,95 @@
import numpy as np
import torch
import skimage
from scipy import ndimage
from PIL import Image
from .models import infer_with_zoe_dc
from pytorch3d.structures import Pointclouds

import math

def nearest_neighbor_fill(img, mask, erosion=0):
    img_ = np.copy(img.cpu().numpy())

    if erosion > 0:
        eroded_mask = skimage.morphology.binary_erosion(mask.cpu().numpy(), footprint=skimage.morphology.disk(erosion))
    else:
        eroded_mask = mask.cpu().numpy()

    img_[eroded_mask <= 0] = np.nan

    distance_to_boundary = ndimage.distance_transform_bf((~eroded_mask>0), metric="cityblock")

    for current_dist in np.unique(distance_to_boundary)[1:]:
        ii, jj = np.where(distance_to_boundary == current_dist)

        ii_ = np.array([ii - 1, ii, ii + 1, ii - 1, ii, ii + 1, ii - 1, ii, ii + 1]).reshape(9, -1)
        jj_ = np.array([jj - 1, jj - 1, jj - 1, jj, jj, jj, jj + 1, jj + 1, jj + 1]).reshape(9, -1)

        ii_ = ii_.clip(0, img_.shape[0] - 1)
        jj_ = jj_.clip(0, img_.shape[1] - 1)

        img_[ii, jj] = np.nanmax(img_[ii_, jj_], axis=0)

    return torch.from_numpy(img_).to(img.device)

def snap_high_gradients_to_nn(depth, threshold=20):
    grad_depth = np.copy(depth.cpu().numpy())
    grad_depth = grad_depth - grad_depth.min()
    grad_depth = grad_depth / grad_depth.max()

    grad = skimage.filters.rank.gradient(grad_depth, skimage.morphology.disk(1))
    return nearest_neighbor_fill(depth, torch.from_numpy(grad < threshold), erosion=3)

def project_points(cameras, depth, use_pixel_centers=True):
    if len(cameras) > 1:
        import warnings
        warnings.warn("project_points assumes only a single camera is used")

    depth_t = torch.from_numpy(depth) if isinstance(depth, np.ndarray) else depth
    depth_t = depth_t.to(cameras.device)

    pixel_center = 0.5 if use_pixel_centers else 0

    fx, fy = cameras.focal_length[0, 1], cameras.focal_length[0, 0]
    cx, cy = cameras.principal_point[0, 1], cameras.principal_point[0, 0]

    i, j = torch.meshgrid(
        torch.arange(cameras.image_size[0][0], dtype=torch.float32, device=cameras.device) + pixel_center,
        torch.arange(cameras.image_size[0][1], dtype=torch.float32, device=cameras.device) + pixel_center,
        indexing="xy",
    )

    directions = torch.stack(
        [-(i - cx) * depth_t / fx, -(j - cy) * depth_t / fy, depth_t], -1
    )

    xy_depth_world = cameras.get_world_to_view_transform().inverse().transform_points(directions.view(-1, 3)).unsqueeze(0)

    return xy_depth_world

def get_pointcloud(xy_depth_world, device="cpu", features=None):
    point_cloud = Pointclouds(points=[xy_depth_world.to(device)], features=[features] if features is not None else None)
    return point_cloud

def merge_pointclouds(point_clouds):
    points = torch.cat([pc.points_padded() for pc in point_clouds], dim=1)
    features = torch.cat([pc.features_padded() for pc in point_clouds], dim=1)
    return Pointclouds(points=[points[0]], features=[features[0]])

def outpaint_with_depth_estimation(image, mask, previous_depth, h, w, pipe, zoe_dc, prompt, cameras, dilation_size: int = 2, depth_scaling: float = 1, generator = None):
    img_input = Image.fromarray((255*image[..., :3].cpu().numpy()).astype(np.uint8))

    # we slightly dilate the mask as aliasing might cause us to receive a too small mask from pytorch3d
    img_mask = Image.fromarray((255*skimage.morphology.isotropic_dilation(((~mask).cpu().numpy()), radius=dilation_size)).astype(np.uint8))#footprint=skimage.morphology.disk(dilation_size)))

    out_image = pipe(prompt=prompt, image=img_input, mask_image=img_mask, height=h, width=w, generator=generator).images[0]
    out_depth = infer_with_zoe_dc(zoe_dc, torch.from_numpy(np.asarray(out_image)/255.).permute(2,0,1).float().to(zoe_dc.device), (previous_depth * mask).to(zoe_dc.device), scaling=depth_scaling).cpu().numpy()

    return out_image, out_depth

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2*math.atan(pixels/(2*focal))
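Note (not part of the commit): the two helpers at the bottom implement the pinhole relation focal = pixels / (2 * tan(fov / 2)) and its inverse. A quick round-trip check, using the same convention app.py and gs.py use (focal length equal to the image width in pixels, giving roughly a 53 degree horizontal field of view):

# Sketch (not part of the commit): round-trip check of fov2focal / focal2fov.
import math
from utils.ops import focal2fov, fov2focal

w = 720                     # image width in pixels
fov = focal2fov(w, w)       # a focal length of w pixels -> 2 * atan(0.5) rad
focal = fov2focal(fov, w)   # back to w pixels

print(round(math.degrees(fov), 2))  # 53.13
print(round(focal, 2))              # 720.0
assert abs(focal - w) < 1e-6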
utils/render.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn.functional as F
import skimage
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    PerspectiveCameras,
    PointsRasterizationSettings,
    PointsRenderer,
    PulsarPointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    NormWeightedCompositor
)
from .ops import nearest_neighbor_fill

from typing import cast, Optional

class PointsRendererWithMasks(PointsRenderer):
    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        fragments = self.rasterizer(point_clouds, **kwargs)

        # Construct weights based on the distance of a point to the true point.
        # However, this could be done differently: e.g. predicted as opposed
        # to a function of the weights.
        r = self.rasterizer.raster_settings.radius

        dists2 = fragments.dists
        weights = torch.ones_like(dists2)  # alternative: 1 - dists2 / (r * r)
        ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float()

        weights = weights * ok

        fragments_prm = fragments.idx.long().permute(0, 3, 1, 2)
        weights_prm = weights.permute(0, 3, 1, 2)
        images = self.compositor(
            fragments_prm,
            weights_prm,
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )

        # alpha-composite the per-pixel z-buffer values into a single depth map
        cumprod = torch.cumprod(1 - weights, dim=-1)
        cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
        depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)

        # permute so image comes at the end
        images = images.permute(0, 2, 3, 1)
        masks = fragments.idx.long()[..., 0] >= 0

        return images, masks, depths

def render_with_settings(cameras, point_cloud, raster_settings, antialiasing: int = 1):
    if antialiasing > 1:
        raster_settings.image_size = (raster_settings.image_size[0] * antialiasing, raster_settings.image_size[1] * antialiasing)

    rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)

    renderer = PointsRendererWithMasks(
        rasterizer=rasterizer,
        compositor=AlphaCompositor()
    )

    if antialiasing > 1:
        images, masks, depths = renderer(point_cloud)

        images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW
        images = F.avg_pool2d(images, kernel_size=antialiasing, stride=antialiasing)
        images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC

        # note: masks and depths are returned at the supersampled resolution
        return images, masks, depths
    else:
        return renderer(point_cloud)


def render(cameras, point_cloud, fill_point_cloud_holes: bool = False, radius: Optional[float] = None, antialiasing: int = 1):
    if fill_point_cloud_holes:
        # coarse pass with a large point radius to find the true silhouette of the point cloud
        coarse_raster_settings = PointsRasterizationSettings(
            image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
            radius=1e-2,
            points_per_pixel=1
        )

        _, coarse_mask, _ = render_with_settings(cameras, point_cloud, coarse_raster_settings)

        eroded_coarse_mask = torch.from_numpy(skimage.morphology.binary_erosion(coarse_mask[0].cpu().numpy(), footprint=skimage.morphology.disk(2)))

        raster_settings = PointsRasterizationSettings(
            image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
            radius=(1 / float(max(cameras.image_size[0, 1], cameras.image_size[0, 0])) * 2.0) if radius is None else radius,
            points_per_pixel=16
        )

        # Render the scene
        images, masks, depths = render_with_settings(cameras, point_cloud, raster_settings)

        # pixels covered in the coarse pass but missed in the fine pass are holes; fill them from their neighbours
        holes_in_rendering = masks[0].cpu() ^ eroded_coarse_mask

        images[0] = nearest_neighbor_fill(images[0], ~holes_in_rendering, 0)
        depths[0] = nearest_neighbor_fill(depths[0], ~holes_in_rendering, 0)

        return images, eroded_coarse_mask.unsqueeze(0).to(masks.device), depths
    else:
        raster_settings = PointsRasterizationSettings(
            image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
            radius=(1 / float(max(cameras.image_size[0, 1], cameras.image_size[0, 0])) * 2.0) if radius is None else radius,
            points_per_pixel=16
        )

        # Render the scene
        return render_with_settings(cameras, point_cloud, raster_settings)
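For orientation only, a minimal sketch of how render could be driven with a PyTorch3D screen-space camera and a coloured point cloud. The shapes, intrinsics and device handling below are illustrative assumptions, not taken from the demo app:

import torch
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import PerspectiveCameras
from utils.render import render

device = "cuda" if torch.cuda.is_available() else "cpu"
points = torch.rand(1, 10_000, 3, device=device) - 0.5   # (N, P, 3) world coordinates
colors = torch.rand(1, 10_000, 3, device=device)          # (N, P, 3) RGB features
point_cloud = Pointclouds(points=points, features=colors)

cameras = PerspectiveCameras(
    focal_length=((500.0, 500.0),),
    principal_point=((256.0, 256.0),),
    image_size=((512, 512),),   # (height, width); render() reads this to size the raster
    in_ndc=False,
    device=device,
)

images, masks, depths = render(cameras, point_cloud, fill_point_cloud_holes=False)
# images: (1, H, W, 3), masks: (1, H, W) bool, depths: (1, H, W)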
utils/scene/__init__.py
ADDED
@@ -0,0 +1,92 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import os
import random
import json
from .utils.system_utils import searchForMaxIteration
from .dataset_readers import sceneLoadTypeCallbacks
from .gaussian_model import GaussianModel
from .utils.camera_utils import cameraList_from_camInfos, camera_to_JSON

class Scene:

    gaussians : GaussianModel

    def __init__(self, args, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
        """
        :param path: Path to colmap scene main folder.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        self.train_cameras = {}
        self.test_cameras = {}

        if os.path.exists(os.path.join(args.source_path, "sparse")):
            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
        elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
        else:
            assert False, "Could not recognize scene type!"

        if not self.loaded_iter:
            with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file:
                dest_file.write(src_file.read())
            json_cams = []
            camlist = []
            if scene_info.test_cameras:
                camlist.extend(scene_info.test_cameras)
            if scene_info.train_cameras:
                camlist.extend(scene_info.train_cameras)
            for id, cam in enumerate(camlist):
                json_cams.append(camera_to_JSON(id, cam))
            with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
                json.dump(json_cams, file)

        if shuffle:
            random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
            random.shuffle(scene_info.test_cameras)   # Multi-res consistent random shuffling

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        for resolution_scale in resolution_scales:
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)

        if self.loaded_iter:
            self.gaussians.load_ply(os.path.join(self.model_path,
                                                 "point_cloud",
                                                 "iteration_" + str(self.loaded_iter),
                                                 "point_cloud.ply"))
        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)

    def save(self, iteration):
        point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

    def getTrainCameras(self, scale=1.0):
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        return self.test_cameras[scale]
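A hypothetical driver for this Scene class. In the 3D Gaussian Splatting codebase the argument bundle normally comes from the argument parser, so the SimpleNamespace below is only a stand-in and the paths are placeholders; a CUDA device is assumed:

import os
from types import SimpleNamespace

from utils.scene import Scene, GaussianModel

# Hypothetical argument bundle; only the fields Scene and the camera loaders touch.
args = SimpleNamespace(
    model_path="output/example",        # where input.ply / cameras.json get written
    source_path="data/example_scene",   # must contain sparse/ (COLMAP) or transforms_train.json (Blender)
    images="images",
    eval=False,
    white_background=False,
    resolution=1,
    data_device="cuda",
)
os.makedirs(args.model_path, exist_ok=True)

gaussians = GaussianModel(sh_degree=3)
scene = Scene(args, gaussians)          # loads cameras and initialises the Gaussians from the point cloud
train_cams = scene.getTrainCameras()    # list of Camera objects at resolution scale 1.0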
utils/scene/cameras.py
ADDED
@@ -0,0 +1,76 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import math
from torch import nn
import numpy as np
from .utils.graphics_utils import getWorld2View2, getProjectionMatrix

class Camera(nn.Module):
    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                 image_name, uid, crop_box=None, mask=None,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda"
                 ):
        super(Camera, self).__init__()

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.image_name = image_name
        self.crop_box = crop_box
        self.mask = mask

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, falling back to the default CUDA device")
            self.data_device = torch.device("cuda")

        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]

        self.gt_alpha_mask = gt_alpha_mask

        #if gt_alpha_mask is not None:
        #    self.original_image *= gt_alpha_mask.to(self.data_device)
        #else:
        #    self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy, crop_box=self.crop_box, width=self.image_width, height=self.image_height).transpose(0, 1).cuda()
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]

class MiniCam:
    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
        self.image_width = width
        self.image_height = height
        self.FoVy = fovy
        self.FoVx = fovx
        self.znear = znear
        self.zfar = zfar
        self.world_view_transform = world_view_transform
        self.full_proj_transform = full_proj_transform
        view_inv = torch.inverse(self.world_view_transform)
        self.camera_center = view_inv[3][:3]
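As a small illustration of the transposed (row-vector) matrix convention used here, a MiniCam can be assembled by hand from an identity view matrix and a simple symmetric perspective projection. This sketch ignores the crop-box support of this repo's getProjectionMatrix and is not part of the commit:

import math
import torch

from utils.scene.cameras import MiniCam

znear, zfar = 0.01, 100.0
fovx = fovy = math.radians(60)

world_view_transform = torch.eye(4)        # identity view, stored transposed

# Symmetric perspective frustum, then transposed to match getProjectionMatrix(...).transpose(0, 1)
proj = torch.zeros(4, 4)
proj[0, 0] = 1 / math.tan(fovx / 2)
proj[1, 1] = 1 / math.tan(fovy / 2)
proj[2, 2] = zfar / (zfar - znear)
proj[2, 3] = -(zfar * znear) / (zfar - znear)
proj[3, 2] = 1.0
projection_matrix = proj.transpose(0, 1)

full_proj_transform = world_view_transform @ projection_matrix
cam = MiniCam(512, 512, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
print(cam.camera_center)  # tensor([0., 0., 0.]) for an identity view matrix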
utils/scene/colmap_loader.py
ADDED
@@ -0,0 +1,294 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import numpy as np
import collections
import struct

CameraModel = collections.namedtuple(
    "CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple(
    "Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple(
    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
CAMERA_MODELS = {
    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
    CameraModel(model_id=7, model_name="FOV", num_params=5),
    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
                         for camera_model in CAMERA_MODELS])
CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
                           for camera_model in CAMERA_MODELS])


def qvec2rotmat(qvec):
    return np.array([
        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])

def rotmat2qvec(R):
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    if qvec[0] < 0:
        qvec *= -1
    return qvec

class Image(BaseImage):
    def qvec2rotmat(self):
        return qvec2rotmat(self.qvec)

def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)

def read_points3D_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    xyzs = None
    rgbs = None
    errors = None
    num_points = 0
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                num_points += 1

    xyzs = np.empty((num_points, 3))
    rgbs = np.empty((num_points, 3))
    errors = np.empty((num_points, 1))
    count = 0
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                xyz = np.array(tuple(map(float, elems[1:4])))
                rgb = np.array(tuple(map(int, elems[4:7])))
                error = np.array(float(elems[7]))
                xyzs[count] = xyz
                rgbs[count] = rgb
                errors[count] = error
                count += 1

    return xyzs, rgbs, errors

def read_points3D_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """

    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]

        xyzs = np.empty((num_points, 3))
        rgbs = np.empty((num_points, 3))
        errors = np.empty((num_points, 1))

        for p_id in range(num_points):
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            xyzs[p_id] = xyz
            rgbs[p_id] = rgb
            errors[p_id] = error
    return xyzs, rgbs, errors

def read_intrinsics_text(path):
    """
    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
    """
    cameras = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                camera_id = int(elems[0])
                model = elems[1]
                assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
                width = int(elems[2])
                height = int(elems[3])
                params = np.array(tuple(map(float, elems[4:])))
                cameras[camera_id] = Camera(id=camera_id, model=model,
                                            width=width, height=height,
                                            params=params)
    return cameras

def read_extrinsics_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            image_name = ""
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":  # look for the ASCII 0 entry
                image_name += current_char.decode("utf-8")
                current_char = read_next_bytes(fid, 1, "c")[0]
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images


def read_intrinsics_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        num_cameras = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_cameras):
            camera_properties = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ")
            camera_id = camera_properties[0]
            model_id = camera_properties[1]
            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
            width = camera_properties[2]
            height = camera_properties[3]
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(fid, num_bytes=8*num_params,
                                     format_char_sequence="d"*num_params)
            cameras[camera_id] = Camera(id=camera_id,
                                        model=model_name,
                                        width=width,
                                        height=height,
                                        params=np.array(params))
        assert len(cameras) == num_cameras
    return cameras


def read_extrinsics_text(path):
    """
    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
    """
    images = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                camera_id = int(elems[8])
                image_name = elems[9]
                elems = fid.readline().split()
                xys = np.column_stack([tuple(map(float, elems[0::3])),
                                       tuple(map(float, elems[1::3]))])
                point3D_ids = np.array(tuple(map(int, elems[2::3])))
                images[image_id] = Image(
                    id=image_id, qvec=qvec, tvec=tvec,
                    camera_id=camera_id, name=image_name,
                    xys=xys, point3D_ids=point3D_ids)
    return images


def read_colmap_bin_array(path):
    """
    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py

    :param path: path to the colmap binary file.
    :return: nd array with the floating point values in the file.
    """
    with open(path, "rb") as fid:
        width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
                                                usecols=(0, 1, 2), dtype=int)
        fid.seek(0)
        num_delimiter = 0
        byte = fid.read(1)
        while True:
            if byte == b"&":
                num_delimiter += 1
                if num_delimiter >= 3:
                    break
            byte = fid.read(1)
        array = np.fromfile(fid, np.float32)
    array = array.reshape((width, height, channels), order="F")
    return np.transpose(array, (1, 0, 2)).squeeze()
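A quick, illustrative round-trip check of the quaternion helpers above (COLMAP stores quaternions as (w, x, y, z)); this snippet is not part of the commit:

import numpy as np
from utils.scene.colmap_loader import qvec2rotmat, rotmat2qvec

# 90-degree rotation about the z-axis as a test case
R = np.array([[0.0, -1.0, 0.0],
              [1.0,  0.0, 0.0],
              [0.0,  0.0, 1.0]])

qvec = rotmat2qvec(R)
R_roundtrip = qvec2rotmat(qvec)
assert np.allclose(R, R_roundtrip, atol=1e-8)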
utils/scene/dataset_readers.py
ADDED
@@ -0,0 +1,270 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import os
import sys
from PIL import Image
from typing import NamedTuple
from .colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
from .utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
import numpy as np
import json
from pathlib import Path
from plyfile import PlyData, PlyElement
from .utils.sh_utils import SH2RGB
from .gaussian_model import BasicPointCloud

class CameraInfo(NamedTuple):
    uid: int
    R: np.array
    T: np.array
    FovY: np.array
    FovX: np.array
    image: np.array
    image_path: str
    image_name: str
    mask: np.array
    mask_path: str
    width: int
    height: int

class SceneInfo(NamedTuple):
    point_cloud: BasicPointCloud
    train_cameras: list
    test_cameras: list
    nerf_normalization: dict
    ply_path: str

def getNerfppNorm(cam_info):
    def get_center_and_diag(cam_centers):
        cam_centers = np.hstack(cam_centers)
        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
        center = avg_cam_center
        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
        diagonal = np.max(dist)
        return center.flatten(), diagonal

    cam_centers = []

    for cam in cam_info:
        W2C = getWorld2View2(cam.R, cam.T)
        C2W = np.linalg.inv(W2C)
        cam_centers.append(C2W[:3, 3:4])

    center, diagonal = get_center_and_diag(cam_centers)
    radius = diagonal * 1.1

    translate = -center

    return {"translate": translate, "radius": radius}

def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, masks_folder):
    cam_infos = []
    for idx, key in enumerate(cam_extrinsics):
        # single-line progress indicator
        sys.stdout.write('\r')
        sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
        sys.stdout.flush()

        extr = cam_extrinsics[key]
        intr = cam_intrinsics[extr.camera_id]
        height = intr.height
        width = intr.width

        uid = intr.id
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)

        if intr.model == "SIMPLE_PINHOLE":
            focal_length_x = intr.params[0]
            FovY = focal2fov(focal_length_x, height)
            FovX = focal2fov(focal_length_x, width)
        elif intr.model == "PINHOLE":
            focal_length_x = intr.params[0]
            focal_length_y = intr.params[1]
            FovY = focal2fov(focal_length_y, height)
            FovX = focal2fov(focal_length_x, width)
        else:
            assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"

        image_path = os.path.join(images_folder, os.path.basename(extr.name))
        image_name = os.path.basename(image_path).split(".")[0]
        image = Image.open(image_path)

        mask_path = os.path.join(masks_folder, os.path.basename(extr.name).replace(".jpg", ".png"))
        try:
            mask = Image.open(mask_path)
        except:
            mask = None

        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, mask=mask, mask_path=mask_path,
                              image_path=image_path, image_name=image_name, width=width, height=height)
        cam_infos.append(cam_info)
    sys.stdout.write('\n')
    return cam_infos

def fetchPly(path):
    plydata = PlyData.read(path)
    vertices = plydata['vertex']
    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
    colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
    normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
    return BasicPointCloud(points=positions, colors=colors, normals=normals)

def storePly(path, xyz, rgb):
    # Define the dtype for the structured array
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    normals = np.zeros_like(xyz)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    elements[:] = list(map(tuple, attributes))

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)

def readColmapSceneInfo(path, images, eval, llffhold=8):
    try:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
    except:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)

    reading_dir = "images" if images is None else images
    # FIXME in post
    mask_reading_dir = "masks"  # if images is None else images
    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), masks_folder=os.path.join(path, mask_reading_dir))
    cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)

    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "sparse/0/points3D.ply")
    bin_path = os.path.join(path, "sparse/0/points3D.bin")
    txt_path = os.path.join(path, "sparse/0/points3D.txt")
    if not os.path.exists(ply_path):
        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)
    except:
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
    cam_infos = []

    with open(os.path.join(path, transformsfile)) as json_file:
        contents = json.load(json_file)
        fovx = contents["camera_angle_x"]

        frames = contents["frames"]
        for idx, frame in enumerate(frames):
            cam_name = os.path.join(path, frame["file_path"] + extension)

            # NeRF 'transform_matrix' is a camera-to-world transform
            c2w = np.array(frame["transform_matrix"])
            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
            c2w[:3, 1:3] *= -1

            # get the world-to-camera transform and set R, T
            w2c = np.linalg.inv(c2w)
            R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
            T = w2c[:3, 3]

            image_path = os.path.join(path, cam_name)
            image_name = Path(cam_name).stem
            image = Image.open(image_path)

            im_data = np.array(image.convert("RGBA"))

            bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

            norm_data = im_data / 255.0
            arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
            image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")

            fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
            FovY = fovy
            FovX = fovx

            # mask and mask_path are required fields of CameraInfo; the Blender path has no masks
            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, mask=None, mask_path="",
                                        image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))

    return cam_infos

def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
    print("Reading Training Transforms")
    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
    print("Reading Test Transforms")
    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)

    if not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "points3d.ply")
    if not os.path.exists(ply_path):
        # Since this data set has no colmap data, we start with random points
        num_pts = 100_000
        print(f"Generating random point cloud ({num_pts})...")

        # We create random points inside the bounds of the synthetic Blender scenes
        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
        shs = np.random.random((num_pts, 3)) / 255.0
        pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))

        storePly(ply_path, xyz, SH2RGB(shs) * 255)
    try:
        pcd = fetchPly(ply_path)
    except:
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

sceneLoadTypeCallbacks = {
    "Colmap": readColmapSceneInfo,
    "Blender": readNerfSyntheticInfo
}
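The callback table above is what Scene uses to pick a loader; calling it directly looks roughly like this, with placeholder paths and the default COLMAP layout (sparse/0 plus images/, optionally masks/) assumed:

from utils.scene.dataset_readers import sceneLoadTypeCallbacks

# signature: readColmapSceneInfo(path, images, eval, llffhold=8)
scene_info = sceneLoadTypeCallbacks["Colmap"]("data/example_scene", "images", False)

print(len(scene_info.train_cameras), "training cameras")
print("normalisation radius:", scene_info.nerf_normalization["radius"])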
utils/scene/gaussian_model.py
ADDED
@@ -0,0 +1,416 @@
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact [email protected]
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
from .utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
|
15 |
+
from torch import nn
|
16 |
+
import os
|
17 |
+
from .utils.system_utils import mkdir_p
|
18 |
+
from plyfile import PlyData, PlyElement
|
19 |
+
from .utils.sh_utils import RGB2SH
|
20 |
+
from .utils.graphics_utils import BasicPointCloud
|
21 |
+
from .utils.general_utils import strip_symmetric, build_scaling_rotation
|
22 |
+
|
23 |
+
from scipy.spatial import KDTree
|
24 |
+
|
25 |
+
# credit to https://github.com/graphdeco-inria/gaussian-splatting/issues/292#issuecomment-2007934451
|
26 |
+
def distCUDA2(points):
|
27 |
+
points_np = points.detach().cpu().float().numpy()
|
28 |
+
dists, inds = KDTree(points_np).query(points_np, k=4)
|
29 |
+
meanDists = (dists[:, 1:] ** 2).mean(1)
|
30 |
+
|
31 |
+
return torch.tensor(meanDists, dtype=points.dtype, device=points.device)
|
32 |
+
|
33 |
+
class GaussianModel:
|
34 |
+
|
35 |
+
def setup_functions(self):
|
36 |
+
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
|
37 |
+
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
|
38 |
+
actual_covariance = L @ L.transpose(1, 2)
|
39 |
+
symm = strip_symmetric(actual_covariance)
|
40 |
+
return symm
|
41 |
+
|
42 |
+
self.scaling_activation = torch.exp
|
43 |
+
self.scaling_inverse_activation = torch.log
|
44 |
+
|
45 |
+
self.covariance_activation = build_covariance_from_scaling_rotation
|
46 |
+
|
47 |
+
self.opacity_activation = torch.sigmoid
|
48 |
+
self.inverse_opacity_activation = inverse_sigmoid
|
49 |
+
|
50 |
+
self.rotation_activation = torch.nn.functional.normalize
|
51 |
+
|
52 |
+
|
53 |
+
def __init__(self, sh_degree : int):
|
54 |
+
self.active_sh_degree = 0
|
55 |
+
self.max_sh_degree = sh_degree
|
56 |
+
self._xyz = torch.empty(0)
|
57 |
+
self._features_dc = torch.empty(0)
|
58 |
+
self._features_rest = torch.empty(0)
|
59 |
+
self._scaling = torch.empty(0)
|
60 |
+
self._rotation = torch.empty(0)
|
61 |
+
self._opacity = torch.empty(0)
|
62 |
+
self.max_radii2D = torch.empty(0)
|
63 |
+
self.xyz_gradient_accum = torch.empty(0)
|
64 |
+
self.denom = torch.empty(0)
|
65 |
+
self.optimizer = None
|
66 |
+
self.percent_dense = 0
|
67 |
+
self.spatial_lr_scale = 0
|
68 |
+
self.setup_functions()
|
69 |
+
|
70 |
+
def capture(self):
|
71 |
+
return (
|
72 |
+
self.active_sh_degree,
|
73 |
+
self._xyz,
|
74 |
+
self._features_dc,
|
75 |
+
self._features_rest,
|
76 |
+
self._scaling,
|
77 |
+
self._rotation,
|
78 |
+
self._opacity,
|
79 |
+
self.max_radii2D,
|
80 |
+
self.xyz_gradient_accum,
|
81 |
+
self.denom,
|
82 |
+
self.optimizer.state_dict(),
|
83 |
+
self.spatial_lr_scale,
|
84 |
+
)
|
85 |
+
|
86 |
+
def restore(self, model_args, training_args):
|
87 |
+
(self.active_sh_degree,
|
88 |
+
self._xyz,
|
89 |
+
self._features_dc,
|
90 |
+
self._features_rest,
|
91 |
+
self._scaling,
|
92 |
+
self._rotation,
|
93 |
+
self._opacity,
|
94 |
+
self.max_radii2D,
|
95 |
+
xyz_gradient_accum,
|
96 |
+
denom,
|
97 |
+
opt_dict,
|
98 |
+
self.spatial_lr_scale) = model_args
|
99 |
+
self.training_setup(training_args)
|
100 |
+
self.xyz_gradient_accum = xyz_gradient_accum
|
101 |
+
self.denom = denom
|
102 |
+
self.optimizer.load_state_dict(opt_dict)
|
103 |
+
|
104 |
+
@property
|
105 |
+
def get_scaling(self):
|
106 |
+
return self.scaling_activation(self._scaling)
|
107 |
+
|
108 |
+
@property
|
109 |
+
def get_rotation(self):
|
110 |
+
return self.rotation_activation(self._rotation)
|
111 |
+
|
112 |
+
@property
|
113 |
+
def get_xyz(self):
|
114 |
+
return self._xyz
|
115 |
+
|
116 |
+
@property
|
117 |
+
def get_features(self):
|
118 |
+
features_dc = self._features_dc
|
119 |
+
features_rest = self._features_rest
|
120 |
+
return torch.cat((features_dc, features_rest), dim=1)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def get_opacity(self):
|
124 |
+
return self.opacity_activation(self._opacity)
|
125 |
+
|
126 |
+
def get_covariance(self, scaling_modifier = 1):
|
127 |
+
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
|
128 |
+
|
129 |
+
def oneupSHdegree(self):
|
130 |
+
if self.active_sh_degree < self.max_sh_degree:
|
131 |
+
self.active_sh_degree += 1
|
132 |
+
|
133 |
+
def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
|
134 |
+
self.spatial_lr_scale = spatial_lr_scale
|
135 |
+
fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
|
136 |
+
fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
|
137 |
+
features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
|
138 |
+
features[:, :3, 0 ] = fused_color
|
139 |
+
features[:, 3:, 1:] = 0.0
|
140 |
+
|
141 |
+
print("Number of points at initialisation : ", fused_point_cloud.shape[0])
|
142 |
+
|
143 |
+
dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
|
144 |
+
scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
|
145 |
+
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
|
146 |
+
rots[:, 0] = 1
|
147 |
+
|
148 |
+
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
|
149 |
+
|
150 |
+
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
|
151 |
+
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
|
152 |
+
self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
|
153 |
+
self._scaling = nn.Parameter(scales.requires_grad_(True))
|
154 |
+
self._rotation = nn.Parameter(rots.requires_grad_(True))
|
155 |
+
self._opacity = nn.Parameter(opacities.requires_grad_(True))
|
156 |
+
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
157 |
+
|
158 |
+
def training_setup(self, training_args):
|
159 |
+
self.percent_dense = training_args.percent_dense
|
160 |
+
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
161 |
+
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
162 |
+
|
163 |
+
l = [
|
164 |
+
{'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
|
165 |
+
{'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
|
166 |
+
{'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
|
167 |
+
{'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
|
168 |
+
{'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
|
169 |
+
{'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
|
170 |
+
]
|
171 |
+
|
172 |
+
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
|
173 |
+
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
|
174 |
+
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
|
175 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
176 |
+
max_steps=training_args.position_lr_max_steps)
|
177 |
+
|
178 |
+
def update_learning_rate(self, iteration):
|
179 |
+
''' Learning rate scheduling per step '''
|
180 |
+
for param_group in self.optimizer.param_groups:
|
181 |
+
if param_group["name"] == "xyz":
|
182 |
+
lr = self.xyz_scheduler_args(iteration)
|
183 |
+
param_group['lr'] = lr
|
184 |
+
return lr
|
185 |
+
|
186 |
+
def construct_list_of_attributes(self):
|
187 |
+
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
|
188 |
+
# All channels except the 3 DC
|
189 |
+
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
|
190 |
+
l.append('f_dc_{}'.format(i))
|
191 |
+
for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
|
192 |
+
l.append('f_rest_{}'.format(i))
|
193 |
+
l.append('opacity')
|
194 |
+
for i in range(self._scaling.shape[1]):
|
195 |
+
l.append('scale_{}'.format(i))
|
196 |
+
for i in range(self._rotation.shape[1]):
|
197 |
+
l.append('rot_{}'.format(i))
|
198 |
+
return l
|
199 |
+
|
200 |
+
def save_ply(self, path):
|
201 |
+
mkdir_p(os.path.dirname(path))
|
202 |
+
|
203 |
+
xyz = self._xyz.detach().cpu().numpy()
|
204 |
+
normals = np.zeros_like(xyz)
|
205 |
+
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
|
206 |
+
f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
|
207 |
+
opacities = self._opacity.detach().cpu().numpy()
|
208 |
+
scale = self._scaling.detach().cpu().numpy()
|
209 |
+
rotation = self._rotation.detach().cpu().numpy()
|
210 |
+
|
211 |
+
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
|
212 |
+
|
213 |
+
elements = np.empty(xyz.shape[0], dtype=dtype_full)
|
214 |
+
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
|
215 |
+
elements[:] = list(map(tuple, attributes))
|
216 |
+
el = PlyElement.describe(elements, 'vertex')
|
217 |
+
PlyData([el]).write(path)
|
218 |
+
|
219 |
+
def reset_opacity(self):
|
220 |
+
opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
|
221 |
+
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
|
222 |
+
self._opacity = optimizable_tensors["opacity"]
|
223 |
+
|
224 |
+
def load_ply(self, path):
|
225 |
+
plydata = PlyData.read(path)
|
226 |
+
|
227 |
+
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
228 |
+
np.asarray(plydata.elements[0]["y"]),
|
229 |
+
np.asarray(plydata.elements[0]["z"])), axis=1)
|
230 |
+
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
|
231 |
+
|
232 |
+
features_dc = np.zeros((xyz.shape[0], 3, 1))
|
233 |
+
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
|
234 |
+
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
|
235 |
+
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
|
236 |
+
|
237 |
+
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
|
238 |
+
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
|
239 |
+
assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
|
240 |
+
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
|
241 |
+
for idx, attr_name in enumerate(extra_f_names):
|
242 |
+
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
243 |
+
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
|
244 |
+
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
|
245 |
+
|
246 |
+
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
|
247 |
+
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
|
248 |
+
scales = np.zeros((xyz.shape[0], len(scale_names)))
|
249 |
+
for idx, attr_name in enumerate(scale_names):
|
250 |
+
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
251 |
+
|
252 |
+
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
|
253 |
+
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
|
254 |
+
rots = np.zeros((xyz.shape[0], len(rot_names)))
|
255 |
+
for idx, attr_name in enumerate(rot_names):
|
256 |
+
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
257 |
+
|
258 |
+
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
|
259 |
+
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
260 |
+
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
261 |
+
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
|
262 |
+
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
|
263 |
+
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
|
264 |
+
|
265 |
+
self.active_sh_degree = self.max_sh_degree
|
266 |
+
|
267 |
+
def replace_tensor_to_optimizer(self, tensor, name):
|
268 |
+
optimizable_tensors = {}
|
269 |
+
for group in self.optimizer.param_groups:
|
270 |
+
if group["name"] == name:
|
271 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
272 |
+
stored_state["exp_avg"] = torch.zeros_like(tensor)
|
273 |
+
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
|
274 |
+
|
275 |
+
del self.optimizer.state[group['params'][0]]
|
276 |
+
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
|
277 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
278 |
+
|
279 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
280 |
+
return optimizable_tensors
|
281 |
+
|
282 |
+
def _prune_optimizer(self, mask):
|
283 |
+
optimizable_tensors = {}
|
284 |
+
for group in self.optimizer.param_groups:
|
285 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
286 |
+
if stored_state is not None:
|
287 |
+
stored_state["exp_avg"] = stored_state["exp_avg"][mask]
|
288 |
+
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
|
289 |
+
|
290 |
+
del self.optimizer.state[group['params'][0]]
|
291 |
+
group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
|
292 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
293 |
+
|
294 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
295 |
+
else:
|
296 |
+
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
|
297 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
298 |
+
return optimizable_tensors
|
299 |
+
|
300 |
+
def prune_points(self, mask):
|
301 |
+
valid_points_mask = ~mask
|
302 |
+
optimizable_tensors = self._prune_optimizer(valid_points_mask)
|
303 |
+
|
304 |
+
self._xyz = optimizable_tensors["xyz"]
|
305 |
+
self._features_dc = optimizable_tensors["f_dc"]
|
306 |
+
self._features_rest = optimizable_tensors["f_rest"]
|
307 |
+
self._opacity = optimizable_tensors["opacity"]
|
308 |
+
self._scaling = optimizable_tensors["scaling"]
|
309 |
+
self._rotation = optimizable_tensors["rotation"]
|
310 |
+
|
311 |
+
self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
|
312 |
+
|
313 |
+
self.denom = self.denom[valid_points_mask]
|
314 |
+
self.max_radii2D = self.max_radii2D[valid_points_mask]
|
315 |
+
|
316 |
+
def cat_tensors_to_optimizer(self, tensors_dict):
|
317 |
+
optimizable_tensors = {}
|
318 |
+
for group in self.optimizer.param_groups:
|
319 |
+
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:

                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]

        return optimizable_tensors

    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
        d = {"xyz": new_xyz,
        "f_dc": new_features_dc,
        "f_rest": new_features_rest,
        "opacity": new_opacities,
        "scaling" : new_scaling,
        "rotation" : new_rotation}

        optimizable_tensors = self.cat_tensors_to_optimizer(d)
        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]

        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
        n_init_points = self.get_xyz.shape[0]
        # Extract points that satisfy the gradient condition
        padded_grad = torch.zeros((n_init_points), device="cuda")
        padded_grad[:grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
            torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)

        stds = self.get_scaling[selected_pts_mask].repeat(N,1)
        means =torch.zeros((stds.size(0), 3),device="cuda")
        samples = torch.normal(mean=means, std=stds)
        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)

        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
        self.prune_points(prune_filter)

    def densify_and_clone(self, grads, grad_threshold, scene_extent):
        # Extract points that satisfy the gradient condition
        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
            torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)

        new_xyz = self._xyz[selected_pts_mask]
        new_features_dc = self._features_dc[selected_pts_mask]
        new_features_rest = self._features_rest[selected_pts_mask]
        new_opacities = self._opacity[selected_pts_mask]
        new_scaling = self._scaling[selected_pts_mask]
        new_rotation = self._rotation[selected_pts_mask]

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0

        self.densify_and_clone(grads, max_grad, extent)
        self.densify_and_split(grads, max_grad, extent)

        prune_mask = (self.get_opacity < min_opacity).squeeze()
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
        self.prune_points(prune_mask)

        torch.cuda.empty_cache()

    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
        self.denom[update_filter] += 1
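The densification hooks above are normally driven from the training loop: per-Gaussian view-space gradients are accumulated every iteration and densify_and_prune is invoked periodically. The sketch below is illustrative only; the names gaussians, render_pkg, opt and the interval/threshold values are assumptions, not part of this commit.

import torch

# Hypothetical driver for the densification methods defined above (sketch only).
def densification_step(gaussians, render_pkg, iteration, opt, scene_extent):
    viewspace_points = render_pkg["viewspace_points"]    # screen-space means with retained grads
    visibility_filter = render_pkg["visibility_filter"]  # Gaussians visible in this view
    radii = render_pkg["radii"]

    # Keep the largest screen-space radius observed so far for each visible Gaussian.
    gaussians.max_radii2D[visibility_filter] = torch.max(
        gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
    # Accumulate view-space positional gradients, consumed later by densify_and_prune.
    gaussians.add_densification_stats(viewspace_points, visibility_filter)

    if iteration % opt.densification_interval == 0:
        size_threshold = 20 if iteration > opt.opacity_reset_interval else None
        gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005,
                                    scene_extent, size_threshold)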
utils/scene/utils/camera_utils.py
ADDED
@@ -0,0 +1,84 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

from ..cameras import Camera
import numpy as np
from .general_utils import PILtoTorch
from .graphics_utils import fov2focal

WARNED = False

def loadCam(args, id, cam_info, resolution_scale):
    orig_w, orig_h = cam_info.image.size

    if args.resolution in [1, 2, 4, 8]:
        resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
    else:  # should be a type that converts to float
        if args.resolution == -1:
            if orig_w > 1600:
                global WARNED
                if not WARNED:
                    print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
                        "If this is not desired, please explicitly specify '--resolution/-r' as 1")
                    WARNED = True
                global_down = orig_w / 1600
            else:
                global_down = 1
        else:
            global_down = orig_w / args.resolution

        scale = float(global_down) * float(resolution_scale)
        resolution = (int(orig_w / scale), int(orig_h / scale))

    resized_image_rgb = PILtoTorch(cam_info.image, resolution)

    gt_image = resized_image_rgb[:3, ...]
    loaded_mask = None

    if resized_image_rgb.shape[1] == 4:
        loaded_mask = resized_image_rgb[3:4, ...]
    elif cam_info.mask is not None:
        loaded_mask = ~(PILtoTorch(cam_info.mask, resolution)[0:1, ...] > 0)

    return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
                  FoVx=cam_info.FovX, FoVy=cam_info.FovY,
                  image=gt_image, gt_alpha_mask=loaded_mask,
                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)

def cameraList_from_camInfos(cam_infos, resolution_scale, args):
    camera_list = []

    for id, c in enumerate(cam_infos):
        camera_list.append(loadCam(args, id, c, resolution_scale))

    return camera_list

def camera_to_JSON(id, camera : Camera):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = camera.R.transpose()
    Rt[:3, 3] = camera.T
    Rt[3, 3] = 1.0

    W2C = np.linalg.inv(Rt)
    pos = W2C[:3, 3]
    rot = W2C[:3, :3]
    serializable_array_2d = [x.tolist() for x in rot]
    camera_entry = {
        'id' : id,
        'img_name' : camera.image_name,
        'width' : camera.width,
        'height' : camera.height,
        'position': pos.tolist(),
        'rotation': serializable_array_2d,
        'fy' : fov2focal(camera.FovY, camera.height),
        'fx' : fov2focal(camera.FovX, camera.width)
    }
    return camera_entry
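A minimal usage sketch for the helpers above: writing the per-camera JSON entries to a cameras.json file for an external viewer. The function name and output path are assumptions for illustration, not part of this commit.

import json

def write_cameras_json(scene_cameras, out_path="cameras.json"):
    # scene_cameras: list of Camera objects, e.g. produced by cameraList_from_camInfos
    entries = [camera_to_JSON(idx, cam) for idx, cam in enumerate(scene_cameras)]
    with open(out_path, "w") as f:
        json.dump(entries, f)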
utils/scene/utils/general_utils.py
ADDED
@@ -0,0 +1,133 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import sys
from datetime import datetime
import numpy as np
import random

def inverse_sigmoid(x):
    return torch.log(x/(1-x))

def PILtoTorch(pil_image, resolution):
    resized_image_PIL = pil_image.resize(resolution)
    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
    if len(resized_image.shape) == 3:
        return resized_image.permute(2, 0, 1)
    else:
        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)

def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """
    Copied from Plenoxels

    Continuous learning rate decay function. Adapted from JaxNeRF
    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
    function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
    to the normal learning rate when steps>lr_delay_steps.
    :param conf: config subtree 'lr' or similar
    :param max_steps: int, the number of steps during optimization.
    :return HoF which takes step as input
    """

    def helper(step):
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Disable this parameter
            return 0.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay.
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0
        t = np.clip(step / max_steps, 0, 1)
        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper

def strip_lowerdiag(L):
    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")

    uncertainty[:, 0] = L[:, 0, 0]
    uncertainty[:, 1] = L[:, 0, 1]
    uncertainty[:, 2] = L[:, 0, 2]
    uncertainty[:, 3] = L[:, 1, 1]
    uncertainty[:, 4] = L[:, 1, 2]
    uncertainty[:, 5] = L[:, 2, 2]
    return uncertainty

def strip_symmetric(sym):
    return strip_lowerdiag(sym)

def build_rotation(r):
    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])

    q = r / norm[:, None]

    R = torch.zeros((q.size(0), 3, 3), device='cuda')

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
    R[:, 0, 1] = 2 * (x*y - r*z)
    R[:, 0, 2] = 2 * (x*z + r*y)
    R[:, 1, 0] = 2 * (x*y + r*z)
    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
    R[:, 1, 2] = 2 * (y*z - r*x)
    R[:, 2, 0] = 2 * (x*z - r*y)
    R[:, 2, 1] = 2 * (y*z + r*x)
    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
    return R

def build_scaling_rotation(s, r):
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
    R = build_rotation(r)

    L[:,0,0] = s[:,0]
    L[:,1,1] = s[:,1]
    L[:,2,2] = s[:,2]

    L = R @ L
    return L

def safe_state(silent):
    old_f = sys.stdout
    class F:
        def __init__(self, silent):
            self.silent = silent

        def write(self, x):
            if not self.silent:
                if x.endswith("\n"):
                    old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
                else:
                    old_f.write(x)

        def flush(self):
            old_f.flush()

    sys.stdout = F(silent)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(torch.device("cuda:0"))
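As a hedged example of get_expon_lr_func, the snippet below samples a schedule at a few steps; the 1.6e-4 to 1.6e-6 range over 30k steps mirrors common Gaussian-splatting position learning-rate settings and is only an assumption here, not a value read from this repository's config.

# Illustrative schedule check (parameter values are assumptions).
xyz_lr = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
for step in (0, 7_500, 15_000, 30_000):
    # Log-linear decay: starts at lr_init, ends at lr_final.
    print(step, xyz_lr(step))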
utils/scene/utils/graphics_utils.py
ADDED
@@ -0,0 +1,88 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import math
import numpy as np
from typing import NamedTuple

class BasicPointCloud(NamedTuple):
    points : np.array
    colors : np.array
    normals : np.array

def geom_transform_points(points, transf_matrix):
    P, _ = points.shape
    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
    points_hom = torch.cat([points, ones], dim=1)
    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))

    denom = points_out[..., 3:] + 0.0000001
    return (points_out[..., :3] / denom).squeeze(dim=0)

def getWorld2View(R, t):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    return np.float32(Rt)

def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)

def getProjectionMatrix(znear, zfar, fovX, fovY, crop_box=None, width=None, height=None):
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    frustum_width = right - left
    frustum_height = top - bottom

    if crop_box is not None:
        assert width is not None and height is not None
        x, y, w, h = crop_box
        left = left + x / width * frustum_width
        right = left + w / width * frustum_width
        top = top - y / height * frustum_height
        bottom = top - h / height * frustum_height

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2*math.atan(pixels/(2*focal))
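An illustrative check of the projection helpers above; the field of view, image size, and crop box below are arbitrary example values, not taken from this commit.

import math

fov = 0.9  # radians, arbitrary example
assert math.isclose(focal2fov(fov2focal(fov, 800), 800), fov)

# Full-frame projection vs. one restricted to a 200x150 pixel crop at (100, 50)
# of an assumed 800x600 image, using the crop_box extension above.
P_full = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fov, fovY=fov)
P_crop = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fov, fovY=fov,
                             crop_box=(100, 50, 200, 150), width=800, height=600)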
utils/scene/utils/image_utils.py
ADDED
@@ -0,0 +1,19 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch

def mse(img1, img2):
    return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)

def psnr(img1, img2):
    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
    return 20 * torch.log10(1.0 / torch.sqrt(mse))
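A tiny illustrative check of the psnr helper above, using random example tensors (assumption-only data): a uniform error of about 0.01 on images in [0, 1] gives roughly 40 dB.

import torch

a = torch.rand(1, 3, 64, 64)
b = (a + 0.01).clamp(0, 1)
print(psnr(a, b))  # roughly 40 dB for a ~0.01 uniform error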
utils/scene/utils/loss_utils.py
ADDED
@@ -0,0 +1,65 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp

def l1_loss(network_output, gt, reduce=True):
    l1_loss = torch.abs((network_output - gt))
    return l1_loss.mean() if reduce else l1_loss

def l2_loss(network_output, gt):
    return ((network_output - gt) ** 2).mean()

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def ssim(img1, img2, window_size=11, size_average=True):
    channel = img1.size(-3)
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
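A hedged sketch of how these terms are typically combined into a photometric training loss; the 0.2 D-SSIM weight is an assumption for illustration, not a value taken from this commit.

def photometric_loss(rendered, gt, lambda_dssim=0.2):
    # Weighted mix of L1 and (1 - SSIM), as commonly used for Gaussian splatting training.
    return (1.0 - lambda_dssim) * l1_loss(rendered, gt) + lambda_dssim * (1.0 - ssim(rendered, gt))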
utils/scene/utils/sh_utils.py
ADDED
@@ -0,0 +1,118 @@
# Copyright 2021 The PlenOctree Authors.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import torch

C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Currently, 0-3 supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                C1 * y * sh[..., 1] +
                C1 * z * sh[..., 2] -
                C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                    C2[0] * xy * sh[..., 4] +
                    C2[1] * yz * sh[..., 5] +
                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                    C2[3] * xz * sh[..., 7] +
                    C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                        C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                        C3[1] * xy * z * sh[..., 10] +
                        C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
                        C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                        C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                        C3[5] * z * (xx - yy) * sh[..., 14] +
                        C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result

def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5
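A quick degree-0 sanity sketch for eval_sh using random example colors (assumption-only data, not part of the commit): at degree 0 the evaluation reduces to C0 * dc, so RGB2SH followed by eval_sh plus the 0.5 offset recovers the input color.

import torch

rgb = torch.rand(5, 3)                         # example colors
sh_dc = RGB2SH(rgb).unsqueeze(-1)              # [..., 3, 1] DC coefficients
dirs = torch.randn(5, 3)
dirs = dirs / dirs.norm(dim=-1, keepdim=True)  # unit directions (unused at degree 0)
recovered = eval_sh(0, sh_dc, dirs) + 0.5
assert torch.allclose(recovered, rgb, atol=1e-6)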
utils/scene/utils/system_utils.py
ADDED
@@ -0,0 +1,28 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#

from errno import EEXIST
from os import makedirs, path
import os

def mkdir_p(folder_path):
    # Creates a directory. equivalent to using mkdir -p on the command line
    try:
        makedirs(folder_path)
    except OSError as exc: # Python >2.5
        if exc.errno == EEXIST and path.isdir(folder_path):
            pass
        else:
            raise

def searchForMaxIteration(folder):
    saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
    return max(saved_iters)
zoedepth/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Intelligent Systems Lab Org

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
zoedepth/data/__init__.py
ADDED
@@ -0,0 +1,24 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat
zoedepth/data/data_mono.py
ADDED
@@ -0,0 +1,697 @@
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee
|
26 |
+
|
27 |
+
import itertools
|
28 |
+
import os
|
29 |
+
import random
|
30 |
+
from random import choice
|
31 |
+
|
32 |
+
import numpy as np
|
33 |
+
import cv2
|
34 |
+
import torch
|
35 |
+
import torch.nn as nn
|
36 |
+
import torch.utils.data.distributed
|
37 |
+
from zoedepth.utils.easydict import EasyDict as edict
|
38 |
+
from PIL import Image, ImageOps
|
39 |
+
from torch.utils.data import DataLoader, Dataset
|
40 |
+
from torchvision import transforms
|
41 |
+
|
42 |
+
from zoedepth.utils.config import change_dataset
|
43 |
+
|
44 |
+
from .ddad import get_ddad_loader
|
45 |
+
from .diml_indoor_test import get_diml_indoor_loader
|
46 |
+
from .diml_outdoor_test import get_diml_outdoor_loader
|
47 |
+
from .diode import get_diode_loader
|
48 |
+
from .hypersim import get_hypersim_loader
|
49 |
+
from .ibims import get_ibims_loader
|
50 |
+
from .sun_rgbd_loader import get_sunrgbd_loader
|
51 |
+
from .vkitti import get_vkitti_loader
|
52 |
+
from .vkitti2 import get_vkitti2_loader
|
53 |
+
from .places365 import get_places365_loader, Places365
|
54 |
+
from .marigold_nyu import get_marigold_nyu_loader, MarigoldNYU
|
55 |
+
|
56 |
+
from .preprocess import CropParams, get_white_border, get_black_border
|
57 |
+
|
58 |
+
|
59 |
+
def _is_pil_image(img):
|
60 |
+
return isinstance(img, Image.Image)
|
61 |
+
|
62 |
+
|
63 |
+
def _is_numpy_image(img):
|
64 |
+
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
65 |
+
|
66 |
+
|
67 |
+
def preprocessing_transforms(mode, **kwargs):
|
68 |
+
return transforms.Compose([
|
69 |
+
ToTensor(mode=mode, **kwargs)
|
70 |
+
])
|
71 |
+
|
72 |
+
|
73 |
+
class DepthDataLoader(object):
|
74 |
+
def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
|
75 |
+
"""
|
76 |
+
Data loader for depth datasets
|
77 |
+
|
78 |
+
Args:
|
79 |
+
config (dict): Config dictionary. Refer to utils/config.py
|
80 |
+
mode (str): "train" or "online_eval"
|
81 |
+
device (str, optional): Device to load the data on. Defaults to 'cpu'.
|
82 |
+
transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
|
83 |
+
"""
|
84 |
+
|
85 |
+
self.config = config
|
86 |
+
|
87 |
+
if config.dataset == 'ibims':
|
88 |
+
self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
|
89 |
+
return
|
90 |
+
|
91 |
+
if config.dataset == 'sunrgbd':
|
92 |
+
self.data = get_sunrgbd_loader(
|
93 |
+
data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
|
94 |
+
return
|
95 |
+
|
96 |
+
if config.dataset == 'diml_indoor':
|
97 |
+
self.data = get_diml_indoor_loader(
|
98 |
+
data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
|
99 |
+
return
|
100 |
+
|
101 |
+
if config.dataset == 'diml_outdoor':
|
102 |
+
self.data = get_diml_outdoor_loader(
|
103 |
+
data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
|
104 |
+
return
|
105 |
+
|
106 |
+
if "diode" in config.dataset:
|
107 |
+
self.data = get_diode_loader(
|
108 |
+
config[config.dataset+"_root"], batch_size=1, num_workers=1)
|
109 |
+
return
|
110 |
+
|
111 |
+
if config.dataset == 'hypersim_test':
|
112 |
+
self.data = get_hypersim_loader(
|
113 |
+
config.hypersim_test_root, batch_size=1, num_workers=1)
|
114 |
+
return
|
115 |
+
|
116 |
+
if config.dataset == 'vkitti':
|
117 |
+
self.data = get_vkitti_loader(
|
118 |
+
config.vkitti_root, batch_size=1, num_workers=1)
|
119 |
+
return
|
120 |
+
|
121 |
+
if config.dataset == 'vkitti2':
|
122 |
+
self.data = get_vkitti2_loader(
|
123 |
+
config.vkitti2_root, batch_size=1, num_workers=1)
|
124 |
+
return
|
125 |
+
|
126 |
+
if config.dataset == 'ddad':
|
127 |
+
self.data = get_ddad_loader(config.ddad_root, resize_shape=(
|
128 |
+
352, 1216), batch_size=1, num_workers=1)
|
129 |
+
return
|
130 |
+
|
131 |
+
img_size = self.config.get("img_size", None)
|
132 |
+
img_size = img_size if self.config.get(
|
133 |
+
"do_input_resize", False) else None
|
134 |
+
|
135 |
+
if transform is None:
|
136 |
+
transform = preprocessing_transforms(mode, size=img_size)
|
137 |
+
|
138 |
+
if mode == 'train':
|
139 |
+
|
140 |
+
Dataset = DataLoadPreprocess
|
141 |
+
self.training_samples = Dataset(
|
142 |
+
config, mode, transform=transform, device=device)
|
143 |
+
|
144 |
+
if config.distributed and not config.debug_mode:
|
145 |
+
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
|
146 |
+
self.training_samples)
|
147 |
+
else:
|
148 |
+
self.train_sampler = None
|
149 |
+
|
150 |
+
if not config.debug_mode:
|
151 |
+
self.data = DataLoader(self.training_samples,
|
152 |
+
batch_size=config.batch_size,
|
153 |
+
shuffle=(self.train_sampler is None),
|
154 |
+
num_workers=config.workers,
|
155 |
+
pin_memory=True,
|
156 |
+
persistent_workers=True,
|
157 |
+
# prefetch_factor=2,
|
158 |
+
sampler=self.train_sampler)
|
159 |
+
else:
|
160 |
+
self.data = DataLoader(self.training_samples,
|
161 |
+
batch_size=config.batch_size,
|
162 |
+
shuffle=(self.train_sampler is None),
|
163 |
+
num_workers=0,
|
164 |
+
pin_memory=True,
|
165 |
+
# prefetch_factor=2,
|
166 |
+
sampler=self.train_sampler)
|
167 |
+
|
168 |
+
elif mode == 'online_eval':
|
169 |
+
self.testing_samples = DataLoadPreprocess(
|
170 |
+
config, mode, transform=transform)
|
171 |
+
if config.distributed: # redundant. here only for readability and to be more explicit
|
172 |
+
# Give whole test set to all processes (and report evaluation only on one) regardless
|
173 |
+
self.eval_sampler = None
|
174 |
+
else:
|
175 |
+
self.eval_sampler = None
|
176 |
+
self.data = DataLoader(self.testing_samples, 1,
|
177 |
+
shuffle=kwargs.get("shuffle_test", False),
|
178 |
+
num_workers=1,
|
179 |
+
pin_memory=False,
|
180 |
+
sampler=self.eval_sampler)
|
181 |
+
|
182 |
+
elif mode == 'test':
|
183 |
+
self.testing_samples = DataLoadPreprocess(
|
184 |
+
config, mode, transform=transform)
|
185 |
+
self.data = DataLoader(self.testing_samples,
|
186 |
+
1, shuffle=False, num_workers=1)
|
187 |
+
|
188 |
+
else:
|
189 |
+
print(
|
190 |
+
'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
|
191 |
+
|
192 |
+
|
193 |
+
def repetitive_roundrobin(*iterables):
|
194 |
+
"""
|
195 |
+
cycles through iterables but sample wise
|
196 |
+
first yield first sample from first iterable then first sample from second iterable and so on
|
197 |
+
then second sample from first iterable then second sample from second iterable and so on
|
198 |
+
|
199 |
+
If one iterable is shorter than the others, it is repeated until all iterables are exhausted
|
200 |
+
repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E
|
201 |
+
"""
|
202 |
+
# Repetitive roundrobin
|
203 |
+
iterables_ = [iter(it) for it in iterables]
|
204 |
+
exhausted = [False] * len(iterables)
|
205 |
+
while not all(exhausted):
|
206 |
+
for i, it in enumerate(iterables_):
|
207 |
+
try:
|
208 |
+
yield next(it)
|
209 |
+
except StopIteration:
|
210 |
+
exhausted[i] = True
|
211 |
+
iterables_[i] = itertools.cycle(iterables[i])
|
212 |
+
# First elements may get repeated if one iterable is shorter than the others
|
213 |
+
yield next(iterables_[i])
|
214 |
+
|
215 |
+
|
216 |
+
class RepetitiveRoundRobinDataLoader(object):
|
217 |
+
def __init__(self, *dataloaders):
|
218 |
+
self.dataloaders = dataloaders
|
219 |
+
|
220 |
+
def __iter__(self):
|
221 |
+
return repetitive_roundrobin(*self.dataloaders)
|
222 |
+
|
223 |
+
def __len__(self):
|
224 |
+
# First samples get repeated, thats why the plus one
|
225 |
+
return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)
|
226 |
+
|
227 |
+
|
228 |
+
class MixedNYUKITTI(object):
|
229 |
+
def __init__(self, config, mode, device='cpu', **kwargs):
|
230 |
+
config = edict(config)
|
231 |
+
config.workers = config.workers // 2
|
232 |
+
self.config = config
|
233 |
+
nyu_conf = change_dataset(edict(config), 'nyu')
|
234 |
+
kitti_conf = change_dataset(edict(config), 'kitti')
|
235 |
+
|
236 |
+
# make nyu default for testing
|
237 |
+
self.config = config = nyu_conf
|
238 |
+
img_size = self.config.get("img_size", None)
|
239 |
+
img_size = img_size if self.config.get(
|
240 |
+
"do_input_resize", False) else None
|
241 |
+
if mode == 'train':
|
242 |
+
nyu_loader = DepthDataLoader(
|
243 |
+
nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
|
244 |
+
kitti_loader = DepthDataLoader(
|
245 |
+
kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
|
246 |
+
# It has been changed to repetitive roundrobin
|
247 |
+
self.data = RepetitiveRoundRobinDataLoader(
|
248 |
+
nyu_loader, kitti_loader)
|
249 |
+
else:
|
250 |
+
self.data = DepthDataLoader(nyu_conf, mode, device=device).data
|
251 |
+
|
252 |
+
class MixedNYUPlaces365(object):
|
253 |
+
def __init__(self, config, mode, device='cpu', **kwargs):
|
254 |
+
config = edict(config)
|
255 |
+
config.workers = config.workers // 2
|
256 |
+
self.config = config
|
257 |
+
nyu_conf = change_dataset(edict(config), 'nyu')
|
258 |
+
places365_conf = change_dataset(edict(config), 'places365')
|
259 |
+
|
260 |
+
# make nyu default for testing
|
261 |
+
self.config = config = nyu_conf
|
262 |
+
img_size = self.config.get("img_size", None)
|
263 |
+
img_size = img_size if self.config.get(
|
264 |
+
"do_input_resize", False) else None
|
265 |
+
if mode == 'train':
|
266 |
+
nyu_loader = DepthDataLoader(
|
267 |
+
nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
|
268 |
+
places365_loader = DepthDataLoader(
|
269 |
+
places365_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
|
270 |
+
# It has been changed to repetitive roundrobin
|
271 |
+
self.data = RepetitiveRoundRobinDataLoader(
|
272 |
+
nyu_loader, places365_loader)
|
273 |
+
else:
|
274 |
+
self.data = DepthDataLoader(nyu_conf, mode, device=device).data
|
275 |
+
|
276 |
+
def remove_leading_slash(s):
|
277 |
+
if s[0] == '/' or s[0] == '\\':
|
278 |
+
return s[1:]
|
279 |
+
return s
|
280 |
+
|
281 |
+
|
282 |
+
class CachedReader:
|
283 |
+
def __init__(self, shared_dict=None):
|
284 |
+
if shared_dict:
|
285 |
+
self._cache = shared_dict
|
286 |
+
else:
|
287 |
+
self._cache = {}
|
288 |
+
|
289 |
+
def open(self, fpath):
|
290 |
+
im = self._cache.get(fpath, None)
|
291 |
+
if im is None:
|
292 |
+
im = self._cache[fpath] = Image.open(fpath)
|
293 |
+
return im
|
294 |
+
|
295 |
+
|
296 |
+
class ImReader:
|
297 |
+
def __init__(self):
|
298 |
+
pass
|
299 |
+
|
300 |
+
# @cache
|
301 |
+
def open(self, fpath):
|
302 |
+
return Image.open(fpath)
|
303 |
+
|
304 |
+
|
305 |
+
class DataLoadPreprocess(Dataset):
|
306 |
+
def __init__(self, config, mode, transform=None, is_for_online_eval=False, device="cpu", **kwargs):
|
307 |
+
self.config = config
|
308 |
+
if mode == 'online_eval':
|
309 |
+
with open(config.filenames_file_eval, 'r') as f:
|
310 |
+
self.filenames = f.readlines()
|
311 |
+
else:
|
312 |
+
with open(config.filenames_file, 'r') as f:
|
313 |
+
self.filenames = f.readlines()
|
314 |
+
|
315 |
+
self.device = torch.device(device)
|
316 |
+
self.mode = mode
|
317 |
+
self.transform = transform
|
318 |
+
self.to_tensor = ToTensor(mode)
|
319 |
+
self.is_for_online_eval = is_for_online_eval
|
320 |
+
if config.use_shared_dict:
|
321 |
+
self.reader = CachedReader(config.shared_dict)
|
322 |
+
else:
|
323 |
+
self.reader = ImReader()
|
324 |
+
|
325 |
+
if config.dataset == "places365" or config.inpaint_task_probability > 0:
|
326 |
+
places365_conf = change_dataset(edict(config), 'places365')
|
327 |
+
self.places365_data = self.data = Places365(places365_conf.places365_root, places365_conf.places365_depth_root, places365_conf.places365_depth_masks_root, randomize_masks=places365_conf.get("randomize_masks", True), debug_mode=self.config.debug_mode)
|
328 |
+
|
329 |
+
if config.dataset == "marigold_nyu":
|
330 |
+
self.marigold_data = self.data = MarigoldNYU(config.nyu_dir_root, config.marigold_depth_root, debug_mode=self.config.debug_mode)
|
331 |
+
self.config.avoid_boundary = True
|
332 |
+
|
333 |
+
def postprocess(self, sample):
|
334 |
+
return sample
|
335 |
+
|
336 |
+
def __getitem__(self, idx):
|
337 |
+
sample_path = self.filenames[idx] if self.config.dataset not in ('places365', "marigold_nyu") else self.filenames[0]
|
338 |
+
focal = float(sample_path.split()[2])
|
339 |
+
sample = {}
|
340 |
+
|
341 |
+
if self.mode == 'train':
|
342 |
+
depth_mask = None
|
343 |
+
if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
|
344 |
+
image_path = os.path.join(
|
345 |
+
self.config.data_path, remove_leading_slash(sample_path.split()[3]))
|
346 |
+
depth_path = os.path.join(
|
347 |
+
self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
|
348 |
+
|
349 |
+
image = self.reader.open(image_path)
|
350 |
+
depth_gt = self.reader.open(depth_path)
|
351 |
+
w, h = image.size
|
352 |
+
|
353 |
+
elif self.config.dataset == 'places365':
|
354 |
+
image, depth_gt, depth_mask, image_path, depth_path, _ = self.places365_data[idx]
|
355 |
+
h, w = image.shape[:2]
|
356 |
+
|
357 |
+
if image.ndim == 2:
|
358 |
+
image = image.reshape(image.shape[0], image.shape[1], 1)
|
359 |
+
image = np.repeat(image, 3, axis=-1)
|
360 |
+
|
361 |
+
elif self.config.dataset == 'marigold_nyu':
|
362 |
+
image, depth_gt, marigold_gt, image_path, depth_path = self.marigold_data[idx]
|
363 |
+
|
364 |
+
h, w = image.shape[:2]
|
365 |
+
|
366 |
+
if image.ndim == 2:
|
367 |
+
image = image.reshape(image.shape[0], image.shape[1], 1)
|
368 |
+
image = np.repeat(image, 3, axis=-1)
|
369 |
+
|
370 |
+
else:
|
371 |
+
image_path = os.path.join(
|
372 |
+
self.config.data_path, remove_leading_slash(sample_path.split()[0]))
|
373 |
+
depth_path = os.path.join(
|
374 |
+
self.config.gt_path, remove_leading_slash(sample_path.split()[1]))
|
375 |
+
|
376 |
+
image = self.reader.open(image_path)
|
377 |
+
depth_gt = self.reader.open(depth_path)
|
378 |
+
w, h = image.size
|
379 |
+
|
380 |
+
if self.config.inpaint_task_probability > 0:
|
381 |
+
_, _, depth_mask, _, _, _ = self.places365_data[idx]
|
382 |
+
|
383 |
+
if self.config.do_kb_crop:
|
384 |
+
height = image.height
|
385 |
+
width = image.width
|
386 |
+
top_margin = int(height - 352)
|
387 |
+
left_margin = int((width - 1216) / 2)
|
388 |
+
depth_gt = depth_gt.crop(
|
389 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
390 |
+
image = image.crop(
|
391 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
392 |
+
|
393 |
+
# Avoid blank boundaries due to pixel registration?
|
394 |
+
# Train images have white border. Test images have black border.
|
395 |
+
if self.config.dataset in ('nyu', 'marigold_nyu') and self.config.avoid_boundary:
|
396 |
+
# print("Avoiding Blank Boundaries!")
|
397 |
+
# We just crop and pad again with reflect padding to original size
|
398 |
+
# original_size = image.size
|
399 |
+
#crop_params = get_white_border(np.array(255*image, dtype=np.uint8))
|
400 |
+
# crop image down from 640x480 to 624x464
|
401 |
+
crop_params = CropParams(8, 472, 8, 632)
|
402 |
+
|
403 |
+
image = image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
|
404 |
+
depth_gt = depth_gt[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
|
405 |
+
|
406 |
+
# Use reflect padding to fill the blank
|
407 |
+
#image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
|
408 |
+
#image = Image.fromarray(image)
|
409 |
+
|
410 |
+
#depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), 'constant', constant_values=0)
|
411 |
+
#depth_gt = Image.fromarray(depth_gt)
|
412 |
+
|
413 |
+
if self.config.dataset == "marigold_nyu":
|
414 |
+
marigold_gt = marigold_gt[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]
|
415 |
+
|
416 |
+
if self.config.do_random_rotate and (self.config.aug) and self.config.dataset not in ('places365', "marigold_nyu"):
|
417 |
+
random_angle = (random.random() - 0.5) * 2 * self.config.degree
|
418 |
+
image = self.rotate_image(image, random_angle)
|
419 |
+
depth_gt = self.rotate_image(
|
420 |
+
depth_gt, random_angle, flag=Image.NEAREST)
|
421 |
+
|
422 |
+
if self.config.dataset not in ('places365', "marigold_nyu"):
|
423 |
+
image = np.asarray(image, dtype=np.float32) / 255.0
|
424 |
+
depth_gt = np.asarray(depth_gt, dtype=np.float32)
|
425 |
+
depth_gt = np.expand_dims(depth_gt, axis=2)
|
426 |
+
|
427 |
+
if self.config.dataset in ('nyu', 'marigold_nyu'):
|
428 |
+
depth_gt = depth_gt / 1000.0
|
429 |
+
elif self.config.dataset != 'places365':
|
430 |
+
depth_gt = depth_gt / 256.0
|
431 |
+
|
432 |
+
if self.config.aug and (self.config.random_crop) and self.config.dataset not in ('places365', "marigold_nyu"):
|
433 |
+
image, depth_gt = self.random_crop(
|
434 |
+
image, depth_gt, self.config.input_height, self.config.input_width)
|
435 |
+
|
436 |
+
if self.config.aug and self.config.random_translate and self.config.dataset not in ('places365', "marigold_nyu"):
|
437 |
+
# print("Random Translation!")
|
438 |
+
image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)
|
439 |
+
|
440 |
+
mask = np.logical_and(depth_gt > self.config.min_depth,
|
441 |
+
depth_gt < self.config.max_depth).squeeze()[None, ...]
|
442 |
+
|
443 |
+
is_inpainting_sample = self.config.inpaint_task_probability > 0 and (torch.rand(1).item() < self.config.inpaint_task_probability)
|
444 |
+
|
445 |
+
def randomly_scale_depth(depth_to_scale):
|
446 |
+
# scale the mask
|
447 |
+
max_scale_factor = self.config.max_depth / depth_to_scale.max()
|
448 |
+
min_scale_factor = self.config.min_depth / depth_to_scale.min()
|
449 |
+
|
450 |
+
scale_factor = torch.rand(1).item() * (max_scale_factor - min_scale_factor) + min_scale_factor
|
451 |
+
scaled_depth = depth_to_scale * scale_factor
|
452 |
+
|
453 |
+
scaled_depth = scaled_depth.clip(self.config.min_depth, self.config.max_depth)
|
454 |
+
|
455 |
+
return scaled_depth
|
456 |
+
|
457 |
+
if self.config.dataset in ("marigold_nyu"):
|
458 |
+
marigold_mask = (marigold_gt > -1).squeeze()[None, ...]
|
459 |
+
|
460 |
+
if is_inpainting_sample and self.config.random_inpainting_scaling:
|
461 |
+
marigold_gt = randomly_scale_depth(marigold_gt)
|
462 |
+
|
463 |
+
marigold_gt[~marigold_mask[0]] = 0
|
464 |
+
|
465 |
+
depth_gt = marigold_gt
|
466 |
+
mask = marigold_mask
|
467 |
+
|
468 |
+
image, depth_gt, mask = self.train_preprocess(image, depth_gt, mask)
|
469 |
+
|
470 |
+
sample = {'image': image, 'depth': depth_gt, 'focal': focal,
|
471 |
+
'mask': mask, **sample}
|
472 |
+
|
473 |
+
if self.config["depth_channel_mask_augment"]:
|
474 |
+
if self.config.dataset in ("marigold_nyu",):
|
475 |
+
if (not self.config.inpaint_task_probability > 0) and depth_mask is None:
|
476 |
+
depth_mask = np.zeros_like(depth_gt)
|
477 |
+
elif self.config.inpaint_task_probability > 0:
|
478 |
+
# we randomly mask with places365, or provide no sparse input at all
|
479 |
+
if is_inpainting_sample:
|
480 |
+
# upsample depth_mask to match depth_gt
|
481 |
+
depth_mask = torch.nn.functional.interpolate(torch.from_numpy(depth_mask).permute(2, 0, 1).unsqueeze(0), size=depth_gt.shape[:2], mode='nearest').squeeze(0).permute(1, 2, 0).numpy()
|
482 |
+
else:
|
483 |
+
depth_mask = np.zeros_like(depth_gt)
|
484 |
+
|
485 |
+
sample["masked_depth"] = depth_gt * depth_mask
|
486 |
+
|
487 |
+
else:
|
488 |
+
if self.mode == 'online_eval':
|
489 |
+
data_path = self.config.data_path_eval
|
490 |
+
else:
|
491 |
+
data_path = self.config.data_path
|
492 |
+
|
493 |
+
image_path = os.path.join(
|
494 |
+
data_path, remove_leading_slash(sample_path.split()[0]))
|
495 |
+
image = np.asarray(self.reader.open(image_path),
|
496 |
+
dtype=np.float32) / 255.0
|
497 |
+
|
498 |
+
if self.mode == 'online_eval':
|
499 |
+
gt_path = self.config.gt_path_eval
|
500 |
+
depth_path = os.path.join(
|
501 |
+
gt_path, remove_leading_slash(sample_path.split()[1]))
|
502 |
+
has_valid_depth = False
|
503 |
+
try:
|
504 |
+
depth_gt = self.reader.open(depth_path)
|
505 |
+
has_valid_depth = True
|
506 |
+
except IOError:
|
507 |
+
depth_gt = False
|
508 |
+
# print('Missing gt for {}'.format(image_path))
|
509 |
+
|
510 |
+
if has_valid_depth:
|
511 |
+
depth_gt = np.asarray(depth_gt, dtype=np.float32)
|
512 |
+
depth_gt = np.expand_dims(depth_gt, axis=2)
|
513 |
+
if self.config.dataset == 'nyu':
|
514 |
+
depth_gt = depth_gt / 1000.0
|
515 |
+
elif self.config.dataset != 'places365':
|
516 |
+
depth_gt = depth_gt / 256.0
|
517 |
+
|
518 |
+
mask = np.logical_and(
|
519 |
+
depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
|
520 |
+
else:
|
521 |
+
mask = False
|
522 |
+
|
523 |
+
if self.config.do_kb_crop:
|
524 |
+
height = image.shape[0]
|
525 |
+
width = image.shape[1]
|
526 |
+
top_margin = int(height - 352)
|
527 |
+
left_margin = int((width - 1216) / 2)
|
528 |
+
image = image[top_margin:top_margin + 352,
|
529 |
+
left_margin:left_margin + 1216, :]
|
530 |
+
if self.mode == 'online_eval' and has_valid_depth:
|
531 |
+
depth_gt = depth_gt[top_margin:top_margin +
|
532 |
+
352, left_margin:left_margin + 1216, :]
|
533 |
+
|
534 |
+
if self.mode == 'online_eval':
|
535 |
+
sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
|
536 |
+
'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
|
537 |
+
'mask': mask}
|
538 |
+
else:
|
539 |
+
sample = {'image': image, 'focal': focal}
|
540 |
+
|
541 |
+
        if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
            if (self.config.dataset not in ('places365', "marigold_nyu")):
                mask = np.logical_and(depth_gt > self.config.min_depth,
                                      depth_gt < self.config.max_depth).squeeze()[None, ...]
                sample['mask'] = mask

        if self.transform:
            sample = self.transform(sample)

        sample = self.postprocess(sample)
        sample['dataset'] = self.config.dataset

        if self.config.dataset != 'places365':
            sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}
        else:
            sample = {**sample, 'image_path': image_path, 'depth_path': depth_path}

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]

        return img, depth

    def random_translate(self, img, depth, max_t=20):
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        p = self.config.translate_prob
        do_translate = random.random()
        if do_translate > p:
            return img, depth
        x = random.randint(-max_t, max_t)
        y = random.randint(-max_t, max_t)
        M = np.float32([[1, 0, x], [0, 1, y]])
        # print(img.shape, depth.shape)
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
        depth = depth.squeeze()[..., None]  # add channel dim back. Affine warp removes it
        # print("after", img.shape, depth.shape)
        return img, depth

    def train_preprocess(self, image, depth_gt, mask):
        if self.config.aug:
            # Random flipping
            do_flip = random.random()
            if do_flip > 0.5:
                # image is H x W x 3
                image = (image[:, ::-1, :]).copy()
                # depth_gt is H x W x 1
                depth_gt = (depth_gt[:, ::-1, :]).copy()
                # mask is B x H x W
                mask = (mask[:, :, ::-1]).copy()

            # Random gamma, brightness, color augmentation
            do_augment = random.random()
            if do_augment > 0.5:
                image = self.augment_image(image)

        return image, depth_gt, mask

    def augment_image(self, image):
        # gamma augmentation
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # brightness augmentation
        if self.config.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # color augmentation
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def __len__(self):
        return len(self.data) if (self.config.dataset in ('places365', "marigold_nyu") and self.mode != 'online_eval') else len(self.filenames)


class ToTensor(object):
    def __init__(self, mode, do_normalize=False, size=None):
        self.mode = mode
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
        self.size = size
        if size is not None:
            self.resize = transforms.Resize(size=size)
        else:
            self.resize = nn.Identity()

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)
        image = self.resize(image)

        if self.mode == 'test':
            return {'image': image, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal}
        else:
            has_valid_depth = sample['has_valid_depth']
            image = self.resize(image)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
                    'image_path': sample['image_path'], 'depth_path': sample['depth_path']}

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
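For reference, the `augment_image` method above composes three pointwise factors (gamma, brightness, per-channel color) on a float image in [0, 1]. A minimal standalone sketch of the same math, assuming a hypothetical NumPy H x W x 3 array `img` that is not part of the diff:

import random
import numpy as np

img = np.random.rand(480, 640, 3).astype(np.float32)   # placeholder input in [0, 1]
gamma = random.uniform(0.9, 1.1)
brightness = random.uniform(0.9, 1.1)
colors = np.random.uniform(0.9, 1.1, size=3)            # broadcasts over the channel axis
aug = np.clip((img ** gamma) * brightness * colors, 0, 1)  # equivalent effect to augment_image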
zoedepth/data/ddad.py
ADDED
@@ -0,0 +1,117 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self, resize_shape):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x : x
        self.resize = transforms.Resize(resize_shape)

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "ddad"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()

        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DDAD(Dataset):
    def __init__(self, data_dir_root, resize_shape):
        import glob

        # image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
        self.image_files = glob.glob(os.path.join(data_dir_root, '*.png'))
        self.depth_files = [r.replace("_rgb.png", "_depth.npy")
                            for r in self.image_files]
        self.transform = ToTensor(resize_shape)

    def __getitem__(self, idx):

        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.load(depth_path)  # meters

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth)
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs):
    dataset = DDAD(data_dir_root, resize_shape)
    return DataLoader(dataset, batch_size, **kwargs)
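A minimal usage sketch for the loader defined above; the data directory and resize shape are placeholders, not part of this commit:

from zoedepth.data.ddad import get_ddad_loader  # package layout as shown in this commit

loader = get_ddad_loader("datasets/ddad/val", resize_shape=(352, 1216), batch_size=1)
batch = next(iter(loader))
print(batch["image"].shape, batch["depth"].shape)  # resized RGB tensor and raw metric depth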
zoedepth/data/diml_indoor_test.py
ADDED
@@ -0,0 +1,125 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x : x
        self.resize = transforms.Resize((480, 640))

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "diml_indoor"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DIML_Indoor(Dataset):
    def __init__(self, data_dir_root):
        import glob

        # image paths are of the form <data_dir_root>/{HR, LR}/<scene>/{color, depth_filled}/*.png
        self.image_files = glob.glob(os.path.join(
            data_dir_root, "LR", '*', 'color', '*.png'))
        self.depth_files = [r.replace("color", "depth_filled").replace(
            "_c.png", "_depth_filled.png") for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.asarray(Image.open(depth_path),
                           dtype='uint16') / 1000.0  # mm to meters

        # print(np.shape(image))
        # print(np.shape(depth))

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth)

        # return sample
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = DIML_Indoor(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)

# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR")
# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR")
zoedepth/data/diml_outdoor_test.py
ADDED
@@ -0,0 +1,114 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x : x

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DIML_Outdoor(Dataset):
    def __init__(self, data_dir_root):
        import glob

        # image paths are of the form <data_dir_root>/{outleft, depthmap}/*.png
        self.image_files = glob.glob(os.path.join(
            data_dir_root, "*", 'outleft', '*.png'))
        self.depth_files = [r.replace("outleft", "depthmap")
                            for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.asarray(Image.open(depth_path),
                           dtype='uint16') / 1000.0  # mm to meters

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth, dataset="diml_outdoor")

        # return sample
        return self.transform(sample)

    def __len__(self):
        return len(self.image_files)


def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = DIML_Outdoor(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)

# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR")
# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR")
zoedepth/data/diode.py
ADDED
@@ -0,0 +1,125 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x : x
        self.resize = transforms.Resize(480)

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "diode"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()

        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class DIODE(Dataset):
    def __init__(self, data_dir_root):
        import glob

        # image paths are of the form <data_dir_root>/scene_#/scan_#/*.png
        self.image_files = glob.glob(
            os.path.join(data_dir_root, '*', '*', '*.png'))
        self.depth_files = [r.replace(".png", "_depth.npy")
                            for r in self.image_files]
        self.depth_mask_files = [
            r.replace(".png", "_depth_mask.npy") for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]
        depth_mask_path = self.depth_mask_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.load(depth_path)  # in meters
        valid = np.load(depth_mask_path)  # binary

        # depth[depth > 8] = -1
        # depth = depth[..., None]

        sample = dict(image=image, depth=depth, valid=valid)

        # return sample
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_diode_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = DIODE(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)

# get_diode_loader(data_dir_root="datasets/diode/val/outdoor")
zoedepth/data/hypersim.py
ADDED
@@ -0,0 +1,138 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import glob
import os

import h5py
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


def hypersim_distance_to_depth(npyDistance):
    intWidth, intHeight, fltFocal = 1024, 768, 886.81

    npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(
        1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None]
    npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
                                 intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None]
    npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32)
    npyImageplane = np.concatenate(
        [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2)

    npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
    return npyDepth


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x: x
        self.resize = transforms.Resize((480, 640))

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        image = self.resize(image)

        return {'image': image, 'depth': depth, 'dataset': "hypersim"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class HyperSim(Dataset):
    def __init__(self, data_dir_root):
        # image paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_final_preview/*.tonemap.jpg
        # depth paths are of the form <data_dir_root>/<scene>/images/scene_cam_#_final_preview/*.depth_meters.hdf5
        self.image_files = glob.glob(os.path.join(
            data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg'))
        self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace(
            ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0

        # depth from hdf5
        depth_fd = h5py.File(depth_path, "r")
        # in meters (Euclidean distance)
        distance_meters = np.array(depth_fd['dataset'])
        depth = hypersim_distance_to_depth(
            distance_meters)  # in meters (planar depth)

        # depth[depth > 8] = -1
        depth = depth[..., None]

        sample = dict(image=image, depth=depth)
        sample = self.transform(sample)

        if idx == 0:
            print(sample["image"].shape)

        return sample

    def __len__(self):
        return len(self.image_files)


def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = HyperSim(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)
zoedepth/data/ibims.py
ADDED
@@ -0,0 +1,81 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T


class iBims(Dataset):
    def __init__(self, config):
        root_folder = config.ibims_root
        with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f:
            imglist = f.read().split()

        samples = []
        for basename in imglist:
            img_path = os.path.join(root_folder, 'rgb', basename + ".png")
            depth_path = os.path.join(root_folder, 'depth', basename + ".png")
            valid_mask_path = os.path.join(
                root_folder, 'mask_invalid', basename+".png")
            transp_mask_path = os.path.join(
                root_folder, 'mask_transp', basename+".png")

            samples.append(
                (img_path, depth_path, valid_mask_path, transp_mask_path))

        self.samples = samples
        # self.normalize = T.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x : x

    def __getitem__(self, idx):
        img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx]

        img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0
        depth = np.asarray(Image.open(depth_path),
                           dtype=np.uint16).astype('float')*50.0/65535

        mask_valid = np.asarray(Image.open(valid_mask_path))
        mask_transp = np.asarray(Image.open(transp_mask_path))

        # depth = depth * mask_valid * mask_transp
        depth = np.where(mask_valid * mask_transp, depth, -1)

        img = torch.from_numpy(img).permute(2, 0, 1)
        img = self.normalize(img)
        depth = torch.from_numpy(depth).unsqueeze(0)
        return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims')

    def __len__(self):
        return len(self.samples)


def get_ibims_loader(config, batch_size=1, **kwargs):
    dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs)
    return dataloader
zoedepth/data/marigold_nyu.py
ADDED
@@ -0,0 +1,112 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from random import choice


class ToTensor(object):
    def __init__(self):
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        #self.normalize = lambda x : x

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        return {'image': image, 'depth': depth, 'dataset': "marigold_nyu"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class MarigoldNYU(Dataset):
    def __init__(self, nyu_dir_root, marigold_depth_root, debug_mode=False):
        import glob
        import os
        import itertools

        categories = os.listdir(os.path.join(nyu_dir_root))
        if debug_mode:
            categories = categories[:2]

        self.image_files = list(itertools.chain(*[glob.glob(os.path.join(nyu_dir_root, c, "rgb_*.jpg")) for c in categories]))
        self.nyu_depth_files = [os.path.join(nyu_dir_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "png").replace("rgb", "sync_depth") for r in self.image_files]
        self.marigold_depth_files = [os.path.join(marigold_depth_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "npy") for r in self.image_files]

        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        nyu_depth_path = self.nyu_depth_files[idx]
        marigold_depth_path = self.marigold_depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        nyu_depth = np.asarray(Image.open(nyu_depth_path), dtype=np.float32)
        marigold_depth = np.load(marigold_depth_path)

        return image, nyu_depth[..., np.newaxis], marigold_depth[..., np.newaxis], image_path, nyu_depth_path

    def __len__(self):
        return len(self.image_files)


def get_marigold_nyu_loader(nyu_dir_root, marigold_depth_root, batch_size=1, **kwargs):
    dataset = MarigoldNYU(nyu_dir_root, marigold_depth_root)
    return DataLoader(dataset, batch_size, **kwargs)
zoedepth/data/places365.py
ADDED
@@ -0,0 +1,118 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from random import choice


class ToTensor(object):
    def __init__(self):
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        #self.normalize = lambda x : x

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        return {'image': image, 'depth': depth, 'dataset': "places365"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class Places365(Dataset):
    def __init__(self, data_dir_root, depth_dir_root, depth_masks_dir_root, randomize_masks=True, debug_mode=False):
        import glob
        import os
        import itertools

        categories = os.listdir(os.path.join(data_dir_root))
        if debug_mode:
            categories = categories[:2]

        self.image_files = list(itertools.chain(*[glob.glob(os.path.join(data_dir_root, c, "*.jpg")) for c in categories]))
        self.depth_files = [os.path.join(depth_dir_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "npy") for r in self.image_files]
        self.depth_masks_files = [os.path.join(depth_masks_dir_root, os.path.join(*r.split("/")[-2:])).replace("jpg", "npy") for r in self.image_files]

        self.randomize_masks = randomize_masks

        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        if not self.randomize_masks:
            depth_masks_path = self.depth_masks_files[idx]
        else:
            depth_masks_path = choice(self.depth_masks_files)

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.load(depth_path)
        depth_mask = 1 - np.load(depth_masks_path)

        return image, depth[..., np.newaxis], depth_mask[..., np.newaxis], image_path, depth_path, depth_masks_path

    def __len__(self):
        return len(self.image_files)


def get_places365_loader(data_dir_root, depth_dir_root, depth_masks_dir_root, batch_size=1, **kwargs):
    dataset = Places365(data_dir_root, depth_dir_root, depth_masks_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)
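A usage sketch for the loader above; the three directory roots are placeholders. Note that Places365.__getitem__ returns a plain tuple rather than the transformed dict, so the default collate yields batched arrays plus lists of paths:

from zoedepth.data.places365 import get_places365_loader

loader = get_places365_loader("data/places365/rgb", "data/places365/depth", "data/places365/depth_masks", batch_size=2)
image, depth, depth_mask, image_path, depth_path, mask_path = next(iter(loader))
print(image.shape, depth.shape, depth_mask.shape)  # mask is 1 - stored mask, broadcast with a trailing channel axis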
zoedepth/data/preprocess.py
ADDED
@@ -0,0 +1,154 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import numpy as np
from dataclasses import dataclass
from typing import Tuple, List

# dataclass to store the crop parameters
@dataclass
class CropParams:
    top: int
    bottom: int
    left: int
    right: int


def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams:
    gray_image = np.mean(rgb_image, axis=channel_axis)
    h, w = gray_image.shape

    def num_value_pixels(arr):
        return np.sum(np.abs(arr - value) < level_diff_threshold)

    def is_above_tolerance(arr, total_pixels):
        return (num_value_pixels(arr) / total_pixels) > tolerance

    # Crop top border until number of value pixels become below tolerance
    top = min_border
    while is_above_tolerance(gray_image[top, :], w) and top < h-1:
        top += 1
        if top > cut_off:
            break

    # Crop bottom border until number of value pixels become below tolerance
    bottom = h - min_border
    while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0:
        bottom -= 1
        if h - bottom > cut_off:
            break

    # Crop left border until number of value pixels become below tolerance
    left = min_border
    while is_above_tolerance(gray_image[:, left], h) and left < w-1:
        left += 1
        if left > cut_off:
            break

    # Crop right border until number of value pixels become below tolerance
    right = w - min_border
    while is_above_tolerance(gray_image[:, right], h) and right > 0:
        right -= 1
        if w - right > cut_off:
            break

    return CropParams(top, bottom, left, right)


def get_white_border(rgb_image, value=255, **kwargs) -> CropParams:
    """Crops the white border of the RGB.

    Args:
        rgb: RGB image, shape (H, W, 3).
    Returns:
        Crop parameters.
    """
    if value == 255:
        # assert range of values in rgb image is [0, 255]
        assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]."
        assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]."
    elif value == 1:
        # assert range of values in rgb image is [0, 1]
        assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]."

    return get_border_params(rgb_image, value=value, **kwargs)

def get_black_border(rgb_image, **kwargs) -> CropParams:
    """Crops the black border of the RGB.

    Args:
        rgb: RGB image, shape (H, W, 3).

    Returns:
        Crop parameters.
    """

    return get_border_params(rgb_image, value=0, **kwargs)

def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray:
    """Crops the image according to the crop parameters.

    Args:
        image: RGB or depth image, shape (H, W, 3) or (H, W).
        crop_params: Crop parameters.

    Returns:
        Cropped image.
    """
    return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right]

def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]:
    """Crops the images according to the crop parameters.

    Args:
        images: RGB or depth images, shape (H, W, 3) or (H, W).
        crop_params: Crop parameters.

    Returns:
        Cropped images.
    """
    return tuple(crop_image(image, crop_params) for image in images)

def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]:
    """Crops the white and black border of the RGB and depth images.

    Args:
        rgb: RGB image, shape (H, W, 3). This image is used to determine the border.
        other_images: The other images to crop according to the border of the RGB image.
    Returns:
        Cropped RGB and other images.
    """
    # crop black border
    crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
    cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params)

    # crop white border
    crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold)
    cropped_images = crop_images(*cropped_images, crop_params=crop_params)

    return cropped_images
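A quick sketch of how the border-cropping helpers above can be applied to an aligned RGB/depth pair; the arrays here are synthetic placeholders, not data from the repository:

import numpy as np
from zoedepth.data.preprocess import crop_black_or_white_border

rgb = np.zeros((480, 640, 3), dtype=np.uint8)
rgb[40:440, 40:600] = 128            # non-border content surrounded by a black frame
depth = np.random.rand(480, 640)     # depth map aligned with the RGB image
rgb_c, depth_c = crop_black_or_white_border(rgb, depth)
print(rgb_c.shape, depth_c.shape)    # both cropped with the same CropParams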
zoedepth/data/sun_rgbd_loader.py
ADDED
@@ -0,0 +1,106 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class ToTensor(object):
    def __init__(self):
        # self.normalize = transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.normalize = lambda x : x

    def __call__(self, sample):
        image, depth = sample['image'], sample['depth']
        image = self.to_tensor(image)
        image = self.normalize(image)
        depth = self.to_tensor(depth)

        return {'image': image, 'depth': depth, 'dataset': "sunrgbd"}

    def to_tensor(self, pic):

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class SunRGBD(Dataset):
    def __init__(self, data_dir_root):
        # test_file_dirs = loadmat(train_test_file)['alltest'].squeeze()
        # all_test = [t[0].replace("/n/fs/sun3d/data/", "") for t in test_file_dirs]
        # self.all_test = [os.path.join(data_dir_root, t) for t in all_test]
        import glob
        self.image_files = glob.glob(
            os.path.join(data_dir_root, 'rgb', 'rgb', '*'))
        self.depth_files = [
            r.replace("rgb/rgb", "gt/gt").replace("jpg", "png") for r in self.image_files]
        self.transform = ToTensor()

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        depth_path = self.depth_files[idx]

        image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
        depth = np.asarray(Image.open(depth_path), dtype='uint16') / 1000.0
        depth[depth > 8] = -1
        depth = depth[..., None]
        return self.transform(dict(image=image, depth=depth))

    def __len__(self):
        return len(self.image_files)


def get_sunrgbd_loader(data_dir_root, batch_size=1, **kwargs):
    dataset = SunRGBD(data_dir_root)
    return DataLoader(dataset, batch_size, **kwargs)
zoedepth/data/transforms.py
ADDED
@@ -0,0 +1,481 @@
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
import math
|
26 |
+
import random
|
27 |
+
|
28 |
+
import cv2
|
29 |
+
import numpy as np
|
30 |
+
|
31 |
+
|
32 |
+
class RandomFliplr(object):
|
33 |
+
"""Horizontal flip of the sample with given probability.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, probability=0.5):
|
37 |
+
"""Init.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
probability (float, optional): Flip probability. Defaults to 0.5.
|
41 |
+
"""
|
42 |
+
self.__probability = probability
|
43 |
+
|
44 |
+
def __call__(self, sample):
|
45 |
+
prob = random.random()
|
46 |
+
|
47 |
+
if prob < self.__probability:
|
48 |
+
for k, v in sample.items():
|
49 |
+
if len(v.shape) >= 2:
|
50 |
+
sample[k] = np.fliplr(v).copy()
|
51 |
+
|
52 |
+
return sample
|
53 |
+
|
54 |
+
|
55 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
56 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
sample (dict): sample
|
60 |
+
size (tuple): image size
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
tuple: new size
|
64 |
+
"""
|
65 |
+
shape = list(sample["disparity"].shape)
|
66 |
+
|
67 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
68 |
+
return sample
|
69 |
+
|
70 |
+
scale = [0, 0]
|
71 |
+
scale[0] = size[0] / shape[0]
|
72 |
+
scale[1] = size[1] / shape[1]
|
73 |
+
|
74 |
+
scale = max(scale)
|
75 |
+
|
76 |
+
shape[0] = math.ceil(scale * shape[0])
|
77 |
+
shape[1] = math.ceil(scale * shape[1])
|
78 |
+
|
79 |
+
# resize
|
80 |
+
sample["image"] = cv2.resize(
|
81 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
82 |
+
)
|
83 |
+
|
84 |
+
sample["disparity"] = cv2.resize(
|
85 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
86 |
+
)
|
87 |
+
sample["mask"] = cv2.resize(
|
88 |
+
sample["mask"].astype(np.float32),
|
89 |
+
tuple(shape[::-1]),
|
90 |
+
interpolation=cv2.INTER_NEAREST,
|
91 |
+
)
|
92 |
+
sample["mask"] = sample["mask"].astype(bool)
|
93 |
+
|
94 |
+
return tuple(shape)
|
95 |
+
|
96 |
+
|
97 |
+
class RandomCrop(object):
|
98 |
+
"""Get a random crop of the sample with the given size (width, height).
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
width,
|
104 |
+
height,
|
105 |
+
resize_if_needed=False,
|
106 |
+
image_interpolation_method=cv2.INTER_AREA,
|
107 |
+
):
|
108 |
+
"""Init.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
width (int): output width
|
112 |
+
height (int): output height
|
113 |
+
resize_if_needed (bool, optional): If True, sample might be upsampled to ensure
|
114 |
+
                that a crop of size (width, height) is possible. Defaults to False.
|
115 |
+
"""
|
116 |
+
self.__size = (height, width)
|
117 |
+
self.__resize_if_needed = resize_if_needed
|
118 |
+
self.__image_interpolation_method = image_interpolation_method
|
119 |
+
|
120 |
+
def __call__(self, sample):
|
121 |
+
|
122 |
+
shape = sample["disparity"].shape
|
123 |
+
|
124 |
+
if self.__size[0] > shape[0] or self.__size[1] > shape[1]:
|
125 |
+
if self.__resize_if_needed:
|
126 |
+
shape = apply_min_size(
|
127 |
+
sample, self.__size, self.__image_interpolation_method
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
raise Exception(
|
131 |
+
"Output size {} bigger than input size {}.".format(
|
132 |
+
self.__size, shape
|
133 |
+
)
|
134 |
+
)
|
135 |
+
|
136 |
+
offset = (
|
137 |
+
np.random.randint(shape[0] - self.__size[0] + 1),
|
138 |
+
np.random.randint(shape[1] - self.__size[1] + 1),
|
139 |
+
)
|
140 |
+
|
141 |
+
for k, v in sample.items():
|
142 |
+
if k == "code" or k == "basis":
|
143 |
+
continue
|
144 |
+
|
145 |
+
if len(sample[k].shape) >= 2:
|
146 |
+
sample[k] = v[
|
147 |
+
offset[0]: offset[0] + self.__size[0],
|
148 |
+
offset[1]: offset[1] + self.__size[1],
|
149 |
+
]
|
150 |
+
|
151 |
+
return sample
|
152 |
+
|
153 |
+
|
154 |
+
class Resize(object):
|
155 |
+
"""Resize sample to given size (width, height).
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
width,
|
161 |
+
height,
|
162 |
+
resize_target=True,
|
163 |
+
keep_aspect_ratio=False,
|
164 |
+
ensure_multiple_of=1,
|
165 |
+
resize_method="lower_bound",
|
166 |
+
image_interpolation_method=cv2.INTER_AREA,
|
167 |
+
letter_box=False,
|
168 |
+
):
|
169 |
+
"""Init.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
width (int): desired output width
|
173 |
+
height (int): desired output height
|
174 |
+
resize_target (bool, optional):
|
175 |
+
True: Resize the full sample (image, mask, target).
|
176 |
+
False: Resize image only.
|
177 |
+
Defaults to True.
|
178 |
+
keep_aspect_ratio (bool, optional):
|
179 |
+
True: Keep the aspect ratio of the input sample.
|
180 |
+
Output sample might not have the given width and height, and
|
181 |
+
resize behaviour depends on the parameter 'resize_method'.
|
182 |
+
Defaults to False.
|
183 |
+
ensure_multiple_of (int, optional):
|
184 |
+
                Output width and height are constrained to be a multiple of this parameter.
|
185 |
+
Defaults to 1.
|
186 |
+
resize_method (str, optional):
|
187 |
+
"lower_bound": Output will be at least as large as the given size.
|
188 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
189 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
190 |
+
Defaults to "lower_bound".
|
191 |
+
"""
|
192 |
+
self.__width = width
|
193 |
+
self.__height = height
|
194 |
+
|
195 |
+
self.__resize_target = resize_target
|
196 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
197 |
+
self.__multiple_of = ensure_multiple_of
|
198 |
+
self.__resize_method = resize_method
|
199 |
+
self.__image_interpolation_method = image_interpolation_method
|
200 |
+
self.__letter_box = letter_box
|
201 |
+
|
202 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
203 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
204 |
+
|
205 |
+
if max_val is not None and y > max_val:
|
206 |
+
y = (np.floor(x / self.__multiple_of)
|
207 |
+
* self.__multiple_of).astype(int)
|
208 |
+
|
209 |
+
if y < min_val:
|
210 |
+
y = (np.ceil(x / self.__multiple_of)
|
211 |
+
* self.__multiple_of).astype(int)
|
212 |
+
|
213 |
+
return y
|
214 |
+
|
215 |
+
def get_size(self, width, height):
|
216 |
+
# determine new height and width
|
217 |
+
scale_height = self.__height / height
|
218 |
+
scale_width = self.__width / width
|
219 |
+
|
220 |
+
if self.__keep_aspect_ratio:
|
221 |
+
if self.__resize_method == "lower_bound":
|
222 |
+
# scale such that output size is lower bound
|
223 |
+
if scale_width > scale_height:
|
224 |
+
# fit width
|
225 |
+
scale_height = scale_width
|
226 |
+
else:
|
227 |
+
# fit height
|
228 |
+
scale_width = scale_height
|
229 |
+
elif self.__resize_method == "upper_bound":
|
230 |
+
# scale such that output size is upper bound
|
231 |
+
if scale_width < scale_height:
|
232 |
+
# fit width
|
233 |
+
scale_height = scale_width
|
234 |
+
else:
|
235 |
+
# fit height
|
236 |
+
scale_width = scale_height
|
237 |
+
elif self.__resize_method == "minimal":
|
238 |
+
                # scale as little as possible
|
239 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
240 |
+
# fit width
|
241 |
+
scale_height = scale_width
|
242 |
+
else:
|
243 |
+
# fit height
|
244 |
+
scale_width = scale_height
|
245 |
+
else:
|
246 |
+
raise ValueError(
|
247 |
+
f"resize_method {self.__resize_method} not implemented"
|
248 |
+
)
|
249 |
+
|
250 |
+
if self.__resize_method == "lower_bound":
|
251 |
+
new_height = self.constrain_to_multiple_of(
|
252 |
+
scale_height * height, min_val=self.__height
|
253 |
+
)
|
254 |
+
new_width = self.constrain_to_multiple_of(
|
255 |
+
scale_width * width, min_val=self.__width
|
256 |
+
)
|
257 |
+
elif self.__resize_method == "upper_bound":
|
258 |
+
new_height = self.constrain_to_multiple_of(
|
259 |
+
scale_height * height, max_val=self.__height
|
260 |
+
)
|
261 |
+
new_width = self.constrain_to_multiple_of(
|
262 |
+
scale_width * width, max_val=self.__width
|
263 |
+
)
|
264 |
+
elif self.__resize_method == "minimal":
|
265 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
266 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
267 |
+
else:
|
268 |
+
raise ValueError(
|
269 |
+
f"resize_method {self.__resize_method} not implemented")
|
270 |
+
|
271 |
+
return (new_width, new_height)
|
272 |
+
|
273 |
+
def make_letter_box(self, sample):
|
274 |
+
top = bottom = (self.__height - sample.shape[0]) // 2
|
275 |
+
left = right = (self.__width - sample.shape[1]) // 2
|
276 |
+
sample = cv2.copyMakeBorder(
|
277 |
+
sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0)
|
278 |
+
return sample
|
279 |
+
|
280 |
+
def __call__(self, sample):
|
281 |
+
width, height = self.get_size(
|
282 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
283 |
+
)
|
284 |
+
|
285 |
+
# resize sample
|
286 |
+
sample["image"] = cv2.resize(
|
287 |
+
sample["image"],
|
288 |
+
(width, height),
|
289 |
+
interpolation=self.__image_interpolation_method,
|
290 |
+
)
|
291 |
+
|
292 |
+
if self.__letter_box:
|
293 |
+
sample["image"] = self.make_letter_box(sample["image"])
|
294 |
+
|
295 |
+
if self.__resize_target:
|
296 |
+
if "disparity" in sample:
|
297 |
+
sample["disparity"] = cv2.resize(
|
298 |
+
sample["disparity"],
|
299 |
+
(width, height),
|
300 |
+
interpolation=cv2.INTER_NEAREST,
|
301 |
+
)
|
302 |
+
|
303 |
+
if self.__letter_box:
|
304 |
+
sample["disparity"] = self.make_letter_box(
|
305 |
+
sample["disparity"])
|
306 |
+
|
307 |
+
if "depth" in sample:
|
308 |
+
sample["depth"] = cv2.resize(
|
309 |
+
sample["depth"], (width,
|
310 |
+
height), interpolation=cv2.INTER_NEAREST
|
311 |
+
)
|
312 |
+
|
313 |
+
if self.__letter_box:
|
314 |
+
sample["depth"] = self.make_letter_box(sample["depth"])
|
315 |
+
|
316 |
+
sample["mask"] = cv2.resize(
|
317 |
+
sample["mask"].astype(np.float32),
|
318 |
+
(width, height),
|
319 |
+
interpolation=cv2.INTER_NEAREST,
|
320 |
+
)
|
321 |
+
|
322 |
+
if self.__letter_box:
|
323 |
+
sample["mask"] = self.make_letter_box(sample["mask"])
|
324 |
+
|
325 |
+
sample["mask"] = sample["mask"].astype(bool)
|
326 |
+
|
327 |
+
return sample
|
328 |
+
|
329 |
+
|
330 |
+
class ResizeFixed(object):
|
331 |
+
def __init__(self, size):
|
332 |
+
self.__size = size
|
333 |
+
|
334 |
+
def __call__(self, sample):
|
335 |
+
sample["image"] = cv2.resize(
|
336 |
+
sample["image"], self.__size[::-1], interpolation=cv2.INTER_LINEAR
|
337 |
+
)
|
338 |
+
|
339 |
+
sample["disparity"] = cv2.resize(
|
340 |
+
sample["disparity"], self.__size[::-
|
341 |
+
1], interpolation=cv2.INTER_NEAREST
|
342 |
+
)
|
343 |
+
|
344 |
+
sample["mask"] = cv2.resize(
|
345 |
+
sample["mask"].astype(np.float32),
|
346 |
+
self.__size[::-1],
|
347 |
+
interpolation=cv2.INTER_NEAREST,
|
348 |
+
)
|
349 |
+
sample["mask"] = sample["mask"].astype(bool)
|
350 |
+
|
351 |
+
return sample
|
352 |
+
|
353 |
+
|
354 |
+
class Rescale(object):
|
355 |
+
"""Rescale target values to the interval [0, max_val].
|
356 |
+
If input is constant, values are set to max_val / 2.
|
357 |
+
"""
|
358 |
+
|
359 |
+
def __init__(self, max_val=1.0, use_mask=True):
|
360 |
+
"""Init.
|
361 |
+
|
362 |
+
Args:
|
363 |
+
max_val (float, optional): Max output value. Defaults to 1.0.
|
364 |
+
use_mask (bool, optional): Only operate on valid pixels (mask == True). Defaults to True.
|
365 |
+
"""
|
366 |
+
self.__max_val = max_val
|
367 |
+
self.__use_mask = use_mask
|
368 |
+
|
369 |
+
def __call__(self, sample):
|
370 |
+
disp = sample["disparity"]
|
371 |
+
|
372 |
+
if self.__use_mask:
|
373 |
+
mask = sample["mask"]
|
374 |
+
else:
|
375 |
+
            mask = np.ones_like(disp, dtype=bool)  # builtin bool: np.bool was removed in newer NumPy
|
376 |
+
|
377 |
+
if np.sum(mask) == 0:
|
378 |
+
return sample
|
379 |
+
|
380 |
+
min_val = np.min(disp[mask])
|
381 |
+
max_val = np.max(disp[mask])
|
382 |
+
|
383 |
+
if max_val > min_val:
|
384 |
+
sample["disparity"][mask] = (
|
385 |
+
(disp[mask] - min_val) / (max_val - min_val) * self.__max_val
|
386 |
+
)
|
387 |
+
else:
|
388 |
+
sample["disparity"][mask] = np.ones_like(
|
389 |
+
disp[mask]) * self.__max_val / 2.0
|
390 |
+
|
391 |
+
return sample
|
392 |
+
|
393 |
+
|
394 |
+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
395 |
+
class NormalizeImage(object):
|
396 |
+
"""Normlize image by given mean and std.
|
397 |
+
"""
|
398 |
+
|
399 |
+
def __init__(self, mean, std):
|
400 |
+
self.__mean = mean
|
401 |
+
self.__std = std
|
402 |
+
|
403 |
+
def __call__(self, sample):
|
404 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
405 |
+
|
406 |
+
return sample
|
407 |
+
|
408 |
+
|
409 |
+
class DepthToDisparity(object):
|
410 |
+
"""Convert depth to disparity. Removes depth from sample.
|
411 |
+
"""
|
412 |
+
|
413 |
+
def __init__(self, eps=1e-4):
|
414 |
+
self.__eps = eps
|
415 |
+
|
416 |
+
def __call__(self, sample):
|
417 |
+
assert "depth" in sample
|
418 |
+
|
419 |
+
sample["mask"][sample["depth"] < self.__eps] = False
|
420 |
+
|
421 |
+
sample["disparity"] = np.zeros_like(sample["depth"])
|
422 |
+
sample["disparity"][sample["depth"] >= self.__eps] = (
|
423 |
+
1.0 / sample["depth"][sample["depth"] >= self.__eps]
|
424 |
+
)
|
425 |
+
|
426 |
+
del sample["depth"]
|
427 |
+
|
428 |
+
return sample
|
429 |
+
|
430 |
+
|
431 |
+
class DisparityToDepth(object):
|
432 |
+
"""Convert disparity to depth. Removes disparity from sample.
|
433 |
+
"""
|
434 |
+
|
435 |
+
def __init__(self, eps=1e-4):
|
436 |
+
self.__eps = eps
|
437 |
+
|
438 |
+
def __call__(self, sample):
|
439 |
+
assert "disparity" in sample
|
440 |
+
|
441 |
+
disp = np.abs(sample["disparity"])
|
442 |
+
sample["mask"][disp < self.__eps] = False
|
443 |
+
|
444 |
+
# print(sample["disparity"])
|
445 |
+
# print(sample["mask"].sum())
|
446 |
+
# exit()
|
447 |
+
|
448 |
+
sample["depth"] = np.zeros_like(disp)
|
449 |
+
sample["depth"][disp >= self.__eps] = (
|
450 |
+
1.0 / disp[disp >= self.__eps]
|
451 |
+
)
|
452 |
+
|
453 |
+
del sample["disparity"]
|
454 |
+
|
455 |
+
return sample
|
456 |
+
|
457 |
+
|
458 |
+
class PrepareForNet(object):
|
459 |
+
"""Prepare sample for usage as network input.
|
460 |
+
"""
|
461 |
+
|
462 |
+
def __init__(self):
|
463 |
+
pass
|
464 |
+
|
465 |
+
def __call__(self, sample):
|
466 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
467 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
468 |
+
|
469 |
+
if "mask" in sample:
|
470 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
471 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
472 |
+
|
473 |
+
if "disparity" in sample:
|
474 |
+
disparity = sample["disparity"].astype(np.float32)
|
475 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
476 |
+
|
477 |
+
if "depth" in sample:
|
478 |
+
depth = sample["depth"].astype(np.float32)
|
479 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
480 |
+
|
481 |
+
return sample
|
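The transforms above all operate on a plain dict sample holding an HWC float "image", an HxW boolean "mask", and either "depth" or "disparity", so they compose directly with torchvision's Compose. Below is an illustrative chain; the target size, normalization constants, and random input are assumptions for the sketch, not the configuration used elsewhere in this repo.

# Illustrative composition of the transforms defined above; all concrete values are assumptions.
import numpy as np
from torchvision.transforms import Compose
from zoedepth.data.transforms import (
    DepthToDisparity, NormalizeImage, PrepareForNet, Rescale, Resize)

transform = Compose([
    DepthToDisparity(eps=1e-4),        # depth -> 1/depth, near-zero depths masked out
    Resize(384, 384, keep_aspect_ratio=True, ensure_multiple_of=32,
           resize_method="lower_bound"),
    Rescale(max_val=1.0),              # disparity rescaled to [0, 1] over valid pixels
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),                   # HWC float image -> contiguous CHW float32
])

h, w = 480, 640
sample = {
    "image": np.random.rand(h, w, 3).astype(np.float32),
    "depth": np.random.uniform(0.5, 10.0, (h, w)).astype(np.float32),
    "mask": np.ones((h, w), dtype=bool),
}
out = transform(sample)
# With "lower_bound" and ensure_multiple_of=32, the 480x640 input becomes 384x512 here.
print(out["image"].shape, out["disparity"].shape, out["mask"].shape)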
zoedepth/data/vkitti.py
ADDED
@@ -0,0 +1,151 @@
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from torch.utils.data import Dataset, DataLoader
|
27 |
+
from torchvision import transforms
|
28 |
+
import os
|
29 |
+
|
30 |
+
from PIL import Image
|
31 |
+
import numpy as np
|
32 |
+
import cv2
|
33 |
+
|
34 |
+
|
35 |
+
class ToTensor(object):
|
36 |
+
def __init__(self):
|
37 |
+
self.normalize = transforms.Normalize(
|
38 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
39 |
+
# self.resize = transforms.Resize((375, 1242))
|
40 |
+
|
41 |
+
def __call__(self, sample):
|
42 |
+
image, depth = sample['image'], sample['depth']
|
43 |
+
|
44 |
+
image = self.to_tensor(image)
|
45 |
+
image = self.normalize(image)
|
46 |
+
depth = self.to_tensor(depth)
|
47 |
+
|
48 |
+
# image = self.resize(image)
|
49 |
+
|
50 |
+
return {'image': image, 'depth': depth, 'dataset': "vkitti"}
|
51 |
+
|
52 |
+
def to_tensor(self, pic):
|
53 |
+
|
54 |
+
if isinstance(pic, np.ndarray):
|
55 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
56 |
+
return img
|
57 |
+
|
58 |
+
# # handle PIL Image
|
59 |
+
if pic.mode == 'I':
|
60 |
+
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
61 |
+
elif pic.mode == 'I;16':
|
62 |
+
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
63 |
+
else:
|
64 |
+
img = torch.ByteTensor(
|
65 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
66 |
+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
67 |
+
if pic.mode == 'YCbCr':
|
68 |
+
nchannel = 3
|
69 |
+
elif pic.mode == 'I;16':
|
70 |
+
nchannel = 1
|
71 |
+
else:
|
72 |
+
nchannel = len(pic.mode)
|
73 |
+
img = img.view(pic.size[1], pic.size[0], nchannel)
|
74 |
+
|
75 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
76 |
+
if isinstance(img, torch.ByteTensor):
|
77 |
+
return img.float()
|
78 |
+
else:
|
79 |
+
return img
|
80 |
+
|
81 |
+
|
82 |
+
class VKITTI(Dataset):
|
83 |
+
def __init__(self, data_dir_root, do_kb_crop=True):
|
84 |
+
import glob
|
85 |
+
        # image paths are of the form <data_dir_root>/test_color/*.png
|
86 |
+
self.image_files = glob.glob(os.path.join(
|
87 |
+
data_dir_root, "test_color", '*.png'))
|
88 |
+
self.depth_files = [r.replace("test_color", "test_depth")
|
89 |
+
for r in self.image_files]
|
90 |
+
        self.do_kb_crop = do_kb_crop  # honor the constructor argument rather than hard-coding True
|
91 |
+
self.transform = ToTensor()
|
92 |
+
|
93 |
+
def __getitem__(self, idx):
|
94 |
+
image_path = self.image_files[idx]
|
95 |
+
depth_path = self.depth_files[idx]
|
96 |
+
|
97 |
+
image = Image.open(image_path)
|
98 |
+
        # depth = Image.open(depth_path)  # unused; overwritten by cv2.imread below
|
99 |
+
depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
|
100 |
+
cv2.IMREAD_ANYDEPTH)
|
101 |
+
print("dpeth min max", depth.min(), depth.max())
|
102 |
+
|
103 |
+
# print(np.shape(image))
|
104 |
+
# print(np.shape(depth))
|
105 |
+
|
106 |
+
# depth[depth > 8] = -1
|
107 |
+
|
108 |
+
        if self.do_kb_crop and False:  # crop disabled; depth from cv2.imread is a NumPy array without .crop()
|
109 |
+
height = image.height
|
110 |
+
width = image.width
|
111 |
+
top_margin = int(height - 352)
|
112 |
+
left_margin = int((width - 1216) / 2)
|
113 |
+
depth = depth.crop(
|
114 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
115 |
+
image = image.crop(
|
116 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
117 |
+
# uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216]
|
118 |
+
|
119 |
+
image = np.asarray(image, dtype=np.float32) / 255.0
|
120 |
+
# depth = np.asarray(depth, dtype=np.uint16) /1.
|
121 |
+
depth = depth[..., None]
|
122 |
+
sample = dict(image=image, depth=depth)
|
123 |
+
|
124 |
+
# return sample
|
125 |
+
sample = self.transform(sample)
|
126 |
+
|
127 |
+
if idx == 0:
|
128 |
+
print(sample["image"].shape)
|
129 |
+
|
130 |
+
return sample
|
131 |
+
|
132 |
+
def __len__(self):
|
133 |
+
return len(self.image_files)
|
134 |
+
|
135 |
+
|
136 |
+
def get_vkitti_loader(data_dir_root, batch_size=1, **kwargs):
|
137 |
+
dataset = VKITTI(data_dir_root)
|
138 |
+
return DataLoader(dataset, batch_size, **kwargs)
|
139 |
+
|
140 |
+
|
141 |
+
if __name__ == "__main__":
|
142 |
+
loader = get_vkitti_loader(
|
143 |
+
data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti_test")
|
144 |
+
print("Total files", len(loader.dataset))
|
145 |
+
for i, sample in enumerate(loader):
|
146 |
+
print(sample["image"].shape)
|
147 |
+
print(sample["depth"].shape)
|
148 |
+
print(sample["dataset"])
|
149 |
+
print(sample['depth'].min(), sample['depth'].max())
|
150 |
+
if i > 5:
|
151 |
+
break
|
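One detail the __main__ block above does not exercise: extra keyword arguments passed to get_vkitti_loader are forwarded verbatim to torch.utils.data.DataLoader, so worker and pinning options can be set at the call site. A short sketch (the path is a placeholder):

# Hypothetical call showing the **kwargs passthrough to DataLoader; the path is an assumption.
from zoedepth.data.vkitti import get_vkitti_loader

loader = get_vkitti_loader("/path/to/vkitti_test", batch_size=4,
                           shuffle=False, num_workers=2, pin_memory=True)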
zoedepth/data/vkitti2.py
ADDED
@@ -0,0 +1,187 @@
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|
25 |
+
import os
|
26 |
+
|
27 |
+
import cv2
|
28 |
+
import numpy as np
|
29 |
+
import torch
|
30 |
+
from PIL import Image
|
31 |
+
from torch.utils.data import DataLoader, Dataset
|
32 |
+
from torchvision import transforms
|
33 |
+
|
34 |
+
|
35 |
+
class ToTensor(object):
|
36 |
+
def __init__(self):
|
37 |
+
# self.normalize = transforms.Normalize(
|
38 |
+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
39 |
+
self.normalize = lambda x: x
|
40 |
+
# self.resize = transforms.Resize((375, 1242))
|
41 |
+
|
42 |
+
def __call__(self, sample):
|
43 |
+
image, depth = sample['image'], sample['depth']
|
44 |
+
|
45 |
+
image = self.to_tensor(image)
|
46 |
+
image = self.normalize(image)
|
47 |
+
depth = self.to_tensor(depth)
|
48 |
+
|
49 |
+
# image = self.resize(image)
|
50 |
+
|
51 |
+
return {'image': image, 'depth': depth, 'dataset': "vkitti"}
|
52 |
+
|
53 |
+
def to_tensor(self, pic):
|
54 |
+
|
55 |
+
if isinstance(pic, np.ndarray):
|
56 |
+
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
57 |
+
return img
|
58 |
+
|
59 |
+
# # handle PIL Image
|
60 |
+
if pic.mode == 'I':
|
61 |
+
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
62 |
+
elif pic.mode == 'I;16':
|
63 |
+
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
64 |
+
else:
|
65 |
+
img = torch.ByteTensor(
|
66 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
67 |
+
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
68 |
+
if pic.mode == 'YCbCr':
|
69 |
+
nchannel = 3
|
70 |
+
elif pic.mode == 'I;16':
|
71 |
+
nchannel = 1
|
72 |
+
else:
|
73 |
+
nchannel = len(pic.mode)
|
74 |
+
img = img.view(pic.size[1], pic.size[0], nchannel)
|
75 |
+
|
76 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
77 |
+
if isinstance(img, torch.ByteTensor):
|
78 |
+
return img.float()
|
79 |
+
else:
|
80 |
+
return img
|
81 |
+
|
82 |
+
|
83 |
+
class VKITTI2(Dataset):
|
84 |
+
def __init__(self, data_dir_root, do_kb_crop=True, split="test"):
|
85 |
+
import glob
|
86 |
+
|
87 |
+
# image paths are of the form <data_dir_root>/rgb/<scene>/<variant>/frames/<rgb,depth>/Camera<0,1>/rgb_{}.jpg
|
88 |
+
self.image_files = glob.glob(os.path.join(
|
89 |
+
data_dir_root, "rgb", "**", "frames", "rgb", "Camera_0", '*.jpg'), recursive=True)
|
90 |
+
self.depth_files = [r.replace("/rgb/", "/depth/").replace(
|
91 |
+
"rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
|
92 |
+
        self.do_kb_crop = do_kb_crop  # honor the constructor argument rather than hard-coding True
|
93 |
+
self.transform = ToTensor()
|
94 |
+
|
95 |
+
# If train test split is not created, then create one.
|
96 |
+
# Split is such that 8% of the frames from each scene are used for testing.
|
97 |
+
if not os.path.exists(os.path.join(data_dir_root, "train.txt")):
|
98 |
+
import random
|
99 |
+
scenes = set([os.path.basename(os.path.dirname(
|
100 |
+
os.path.dirname(os.path.dirname(f)))) for f in self.image_files])
|
101 |
+
train_files = []
|
102 |
+
test_files = []
|
103 |
+
for scene in scenes:
|
104 |
+
scene_files = [f for f in self.image_files if os.path.basename(
|
105 |
+
os.path.dirname(os.path.dirname(os.path.dirname(f)))) == scene]
|
106 |
+
random.shuffle(scene_files)
|
107 |
+
train_files.extend(scene_files[:int(len(scene_files) * 0.92)])
|
108 |
+
test_files.extend(scene_files[int(len(scene_files) * 0.92):])
|
109 |
+
with open(os.path.join(data_dir_root, "train.txt"), "w") as f:
|
110 |
+
f.write("\n".join(train_files))
|
111 |
+
with open(os.path.join(data_dir_root, "test.txt"), "w") as f:
|
112 |
+
f.write("\n".join(test_files))
|
113 |
+
|
114 |
+
if split == "train":
|
115 |
+
with open(os.path.join(data_dir_root, "train.txt"), "r") as f:
|
116 |
+
self.image_files = f.read().splitlines()
|
117 |
+
self.depth_files = [r.replace("/rgb/", "/depth/").replace(
|
118 |
+
"rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
|
119 |
+
elif split == "test":
|
120 |
+
with open(os.path.join(data_dir_root, "test.txt"), "r") as f:
|
121 |
+
self.image_files = f.read().splitlines()
|
122 |
+
self.depth_files = [r.replace("/rgb/", "/depth/").replace(
|
123 |
+
"rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files]
|
124 |
+
|
125 |
+
def __getitem__(self, idx):
|
126 |
+
image_path = self.image_files[idx]
|
127 |
+
depth_path = self.depth_files[idx]
|
128 |
+
|
129 |
+
image = Image.open(image_path)
|
130 |
+
# depth = Image.open(depth_path)
|
131 |
+
depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
|
132 |
+
cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m
|
133 |
+
depth = Image.fromarray(depth)
|
134 |
+
# print("dpeth min max", depth.min(), depth.max())
|
135 |
+
|
136 |
+
# print(np.shape(image))
|
137 |
+
# print(np.shape(depth))
|
138 |
+
|
139 |
+
if self.do_kb_crop:
|
140 |
+
if idx == 0:
|
141 |
+
print("Using KB input crop")
|
142 |
+
height = image.height
|
143 |
+
width = image.width
|
144 |
+
top_margin = int(height - 352)
|
145 |
+
left_margin = int((width - 1216) / 2)
|
146 |
+
depth = depth.crop(
|
147 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
148 |
+
image = image.crop(
|
149 |
+
(left_margin, top_margin, left_margin + 1216, top_margin + 352))
|
150 |
+
# uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216]
|
151 |
+
|
152 |
+
image = np.asarray(image, dtype=np.float32) / 255.0
|
153 |
+
# depth = np.asarray(depth, dtype=np.uint16) /1.
|
154 |
+
depth = np.asarray(depth, dtype=np.float32) / 1.
|
155 |
+
depth[depth > 80] = -1
|
156 |
+
|
157 |
+
depth = depth[..., None]
|
158 |
+
sample = dict(image=image, depth=depth)
|
159 |
+
|
160 |
+
# return sample
|
161 |
+
sample = self.transform(sample)
|
162 |
+
|
163 |
+
if idx == 0:
|
164 |
+
print(sample["image"].shape)
|
165 |
+
|
166 |
+
return sample
|
167 |
+
|
168 |
+
def __len__(self):
|
169 |
+
return len(self.image_files)
|
170 |
+
|
171 |
+
|
172 |
+
def get_vkitti2_loader(data_dir_root, batch_size=1, **kwargs):
|
173 |
+
dataset = VKITTI2(data_dir_root)
|
174 |
+
return DataLoader(dataset, batch_size, **kwargs)
|
175 |
+
|
176 |
+
|
177 |
+
if __name__ == "__main__":
|
178 |
+
loader = get_vkitti2_loader(
|
179 |
+
data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti2")
|
180 |
+
print("Total files", len(loader.dataset))
|
181 |
+
for i, sample in enumerate(loader):
|
182 |
+
print(sample["image"].shape)
|
183 |
+
print(sample["depth"].shape)
|
184 |
+
print(sample["dataset"])
|
185 |
+
print(sample['depth'].min(), sample['depth'].max())
|
186 |
+
if i > 5:
|
187 |
+
break
|
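As a quick sanity check on the KB crop applied above: for a 1242x375 (width x height) frame, which matches the usual KITTI resolution, the crop keeps the bottom 352 rows and the central 1216 columns. A small sketch of that arithmetic (the frame size is an assumption for illustration):

# KB-style crop box for an assumed 1242x375 frame, mirroring the crop code above.
width, height = 1242, 375
top_margin = int(height - 352)           # 23
left_margin = int((width - 1216) / 2)    # 13
crop_box = (left_margin, top_margin, left_margin + 1216, top_margin + 352)
print(crop_box)  # (13, 23, 1229, 375) -> a 1216x352 region anchored at the bottom center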
zoedepth/models/__init__.py
ADDED
@@ -0,0 +1,24 @@
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# File author: Shariq Farooq Bhat
|
24 |
+
|